1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements converting sparse tensor types to actual sparse code. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CodegenUtils.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 17 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21 #include "mlir/Dialect/SCF/SCF.h" 22 #include "mlir/Dialect/SCF/Transforms.h" 23 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 24 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 25 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 26 #include "mlir/Dialect/StandardOps/IR/Ops.h" 27 #include "mlir/Dialect/Vector/VectorOps.h" 28 #include "mlir/IR/Matchers.h" 29 #include "mlir/IR/TensorEncoding.h" 30 #include "llvm/ADT/SmallBitVector.h" 31 32 using namespace mlir; 33 using namespace mlir::sparse_tensor; 34 35 //===----------------------------------------------------------------------===// 36 // Declarations of data structures. 37 //===----------------------------------------------------------------------===// 38 39 namespace { 40 41 // Iteration graph sorting. 42 enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 }; 43 44 // Reduction kinds. 45 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor }; 46 47 // Code generation. 
48 struct CodeGen { 49 CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops, 50 OpOperand *op, unsigned nest) 51 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 52 pointers(numTensors, std::vector<Value>(numLoops)), 53 indices(numTensors, std::vector<Value>(numLoops)), 54 highs(numTensors, std::vector<Value>(numLoops)), 55 pidxs(numTensors, std::vector<Value>(numLoops)), 56 idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(), 57 redKind(kNoReduc), sparseOut(op), outerParNest(nest), lexIdx(), 58 expValues(), expFilled(), expAdded(), expCount(), curVecLength(1), 59 curVecMask() {} 60 /// Sparsification options. 61 SparsificationOptions options; 62 /// Universal dense indices and upper bounds (by index). The loops array 63 /// is updated with the value of the universal dense index in the current 64 /// loop. The sizes array is set once with the inferred dimension sizes. 65 std::vector<Value> loops; 66 std::vector<Value> sizes; 67 /// Buffers for storing dense and sparse numerical values (by tensor). 68 /// This array is set once during bufferization of all tensors. 69 std::vector<Value> buffers; 70 /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 71 /// This array is set once during bufferization of all sparse tensors. 72 std::vector<std::vector<Value>> pointers; 73 std::vector<std::vector<Value>> indices; 74 /// Sparse iteration information (by tensor and index). These arrays 75 /// are updated to remain current within the current loop. 76 std::vector<std::vector<Value>> highs; 77 std::vector<std::vector<Value>> pidxs; 78 std::vector<std::vector<Value>> idxs; 79 /// Current reduction, updated during code generation. When indices of a 80 /// reduction are exhausted, all inner loops can use a scalarized reduction. 81 unsigned redExp; 82 Value redVal; 83 Reduction redKind; 84 // Sparse tensor as output. 
Implemented either through direct injective 85 // insertion in lexicographic index order (where indices are updated 86 // in the temporary array `lexIdx`) or through access pattern expansion 87 // in the innermost loop nest (`expValues` through `expCount`). 88 OpOperand *sparseOut; 89 unsigned outerParNest; 90 Value lexIdx; 91 Value expValues; 92 Value expFilled; 93 Value expAdded; 94 Value expCount; 95 // Current vector length and mask. 96 unsigned curVecLength; 97 Value curVecMask; 98 }; 99 100 } // namespace 101 102 //===----------------------------------------------------------------------===// 103 // Sparse compiler analysis methods. 104 //===----------------------------------------------------------------------===// 105 106 /// Helper method to apply dimension ordering permutation. 107 static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) { 108 if (enc) { 109 auto order = enc.getDimOrdering(); 110 if (order) { 111 assert(order.isPermutation()); 112 return order.getDimPosition(d); 113 } 114 } 115 return d; 116 } 117 118 /// Helper method to translate dim level type to internal representation. 119 static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) { 120 if (enc) { 121 SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 122 if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 123 return Dim::kSparse; 124 if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 125 return Dim::kSingle; 126 } 127 return Dim::kDense; 128 } 129 130 /// Helper method to inspect affine expressions. Rejects cases where the 131 /// same index is used more than once. Also rejects affine expressions 132 /// that are not a direct index for annotated tensors. 
133 // TODO: accept more affine cases for sparse tensors 134 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim, 135 bool isDense) { 136 switch (a.getKind()) { 137 case AffineExprKind::DimId: { 138 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 139 if (!merger.isDim(tensor, idx, Dim::kUndef)) 140 return false; // used more than once 141 merger.setDim(tensor, idx, dim); 142 return true; 143 } 144 case AffineExprKind::Add: 145 case AffineExprKind::Mul: { 146 if (!isDense) 147 return false; 148 auto binOp = a.cast<AffineBinaryOpExpr>(); 149 return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) && 150 findAffine(merger, tensor, binOp.getRHS(), dim, isDense); 151 } 152 case AffineExprKind::Constant: 153 return isDense; 154 default: 155 return false; 156 } 157 } 158 159 /// Helper method to inspect sparse encodings in the tensor types. 160 /// Fills the per-dimension sparsity information for all tensors. 161 /// Returns true if the sparse annotations and affine subscript 162 /// expressions of all tensors are admissable. Returns false if 163 /// no annotations are found or inadmissable constructs occur. 164 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 165 bool annotated = false; 166 for (OpOperand *t : op.getInputAndOutputOperands()) { 167 auto map = op.getTiedIndexingMap(t); 168 auto enc = getSparseTensorEncoding(t->get().getType()); 169 if (enc) 170 annotated = true; 171 assert(map.getNumResults() == op.getRank(t)); 172 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 173 unsigned tensor = t->getOperandNumber(); 174 AffineExpr a = map.getResult(perm(enc, d)); 175 if (!findAffine(merger, tensor, a, toDim(enc, d), !enc)) 176 return false; // inadmissable affine expression 177 } 178 } 179 return annotated; 180 } 181 182 /// A DFS helper to compute a topological sort. Note that recursion is 183 /// bounded by the number of implicit loops, which is always small. 
184 /// Returns false when a cycle is detected. 185 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 186 std::vector<unsigned> &topSort, 187 std::vector<std::vector<bool>> &adjM) { 188 if (visit[i] != 0) 189 return visit[i] != 1; // 1 denotes cycle! 190 visit[i] = 1; 191 for (unsigned j = 0, e = visit.size(); j < e; j++) 192 if (adjM[i][j]) 193 if (!topSortDFS(j, visit, topSort, adjM)) 194 return false; 195 visit[i] = 2; 196 topSort.push_back(i); 197 return true; 198 } 199 200 /// Helper method to add all constraints from the indices in one affine 201 /// expression before all indices in the other affine expression. For 202 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3. 203 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM, 204 AffineExpr a, AffineExpr b, unsigned fidx) { 205 switch (a.getKind()) { 206 case AffineExprKind::DimId: { 207 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 208 if (b) 209 addAffineOrderings(adjM, b, AffineExpr(), idx); 210 else 211 adjM[fidx][idx] = true; 212 break; 213 } 214 case AffineExprKind::Add: 215 case AffineExprKind::Mul: { 216 auto binOp = a.cast<AffineBinaryOpExpr>(); 217 addAffineOrderings(adjM, binOp.getLHS(), b, fidx); 218 addAffineOrderings(adjM, binOp.getRHS(), b, fidx); 219 break; 220 } 221 default: 222 break; 223 } 224 } 225 226 /// Computes a topologically sorted iteration graph for the linalg operation. 227 /// Ensures all tensors are visited in natural index order. This is essential 228 /// for sparse storage formats since these only support access along fixed 229 /// dimensions. Even for dense storage formats, however, the natural index 230 /// order yields innermost unit-stride access with better spatial locality. 231 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, 232 std::vector<unsigned> &topSort, 233 unsigned mask) { 234 // Set up an n x n from/to adjacency matrix of the iteration graph 235 // for the implicit loop indices i_0 .. 
i_n-1. 236 unsigned n = op.getNumLoops(); 237 std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false)); 238 239 // Iterate over the indexing maps of every tensor in the tensor expression. 240 for (OpOperand *t : op.getInputAndOutputOperands()) { 241 auto map = op.getTiedIndexingMap(t); 242 auto enc = getSparseTensorEncoding(t->get().getType()); 243 assert(map.getNumDims() == n); 244 // Skip dense tensor constraints when not requested. 245 if (!(mask & SortMask::kIncludeDense) && !enc) 246 continue; 247 // Each tensor expression and optional dimension ordering (row-major 248 // by default) puts an ordering constraint on the loop indices. For 249 // example, the tensor expresion A_ijk forces the ordering i < j < k 250 // on the loop indices if no explicit dimension ordering is given. 251 for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) { 252 AffineExpr f = map.getResult(perm(enc, d - 1)); 253 AffineExpr t = map.getResult(perm(enc, d)); 254 addAffineOrderings(adjM, f, t, 0); 255 } 256 // Push unrelated loops into sparse iteration space, so these 257 // will be skipped more often. 258 if (mask & SortMask::kIncludeUndef) { 259 unsigned tensor = t->getOperandNumber(); 260 for (unsigned i = 0; i < n; i++) 261 if (merger.isDim(tensor, i, Dim::kSparse)) 262 for (unsigned j = 0; j < n; j++) 263 if (merger.isDim(tensor, j, Dim::kUndef)) 264 adjM[i][j] = true; 265 } 266 } 267 268 // Topologically sort the iteration graph to determine loop order. 269 // Report failure for a cyclic iteration graph. 270 topSort.clear(); 271 topSort.reserve(n); 272 std::vector<unsigned> visit(n, 0); 273 for (unsigned i = 0; i < n; i++) 274 if (visit[i] == 0) 275 if (!topSortDFS(i, visit, topSort, adjM)) 276 return false; // cycle! 277 std::reverse(std::begin(topSort), std::end(topSort)); 278 return true; 279 } 280 281 /// Returns true if tensor has an in-place annotation. 
282 static bool isInPlace(Value val) { 283 if (auto arg = val.dyn_cast<BlockArgument>()) 284 if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp())) 285 if (auto attr = funcOp.getArgAttrOfType<BoolAttr>( 286 arg.getArgNumber(), 287 linalg::comprehensive_bufferize::BufferizableOpInterface:: 288 kInplaceableAttrName)) 289 return attr.getValue(); 290 return false; 291 } 292 293 /// Returns true if tensor materializes uninitialized into the computation. 294 static bool isMaterializing(Value val) { 295 return val.getDefiningOp<linalg::InitTensorOp>() || 296 val.getDefiningOp<InitOp>(); 297 } 298 299 /// Returns true when the tensor expression is admissable for codegen. 300 /// Since all sparse input tensors are admissable, we just need to check 301 /// whether the out tensor in the tensor expression codegen is admissable. 302 /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective 303 /// nesting depth when a "truly dynamic" sparse tensor output occurs. 304 static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op, 305 std::vector<unsigned> &topSort, unsigned exp, 306 OpOperand **sparseOut, 307 unsigned &outerParNest) { 308 OpOperand *lhs = op.getOutputOperand(0); 309 unsigned tensor = lhs->getOperandNumber(); 310 auto enc = getSparseTensorEncoding(lhs->get().getType()); 311 // An non-annotated output tensor is assumed dense, and becomes a random 312 // access n-dim memref. Admissable since insertions cannot occur. 313 if (!enc) 314 return true; 315 // An all-dense annotated "sparse" output tensor becomes a linearized random 316 // access 1-dim memref. Also admissable since insertions cannot occur. 
317 bool allDense = true; 318 auto iteratorTypes = op.iterator_types().getValue(); 319 unsigned numLoops = iteratorTypes.size(); 320 for (unsigned i = 0; i < numLoops; i++) 321 if (merger.isDim(tensor, i, Dim::kSparse)) { 322 allDense = false; 323 break; 324 } 325 if (allDense) 326 return true; 327 // A tensor expression with a sparse output tensor that changes its values 328 // but not its nonzero structure, an operation called "simply dynamic" in 329 // [Bik96,Ch9], is also admissable without special codegen, provided 330 // the tensor's underlying sparse storage scheme can be modified in place. 331 if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get())) 332 return true; 333 // Accept "truly dynamic" if the output tensor materializes uninitialized 334 // into the computation and insertions occur in lexicographic index order. 335 if (isMaterializing(lhs->get())) { 336 unsigned nest = 0; 337 for (unsigned i = 0; i < numLoops; i++) { 338 if (isReductionIterator(iteratorTypes[topSort[i]])) 339 break; // terminate at first reduction 340 nest++; 341 } 342 // Determine admissable dynamic insertion situations: 343 // (1) fully injective, since there are no reductions, 344 // (2) admissable 1-d expansion in innermost dimension. 345 if (nest >= op.getRank(lhs) - 1) { 346 *sparseOut = lhs; 347 outerParNest = nest; 348 return true; 349 } 350 } 351 return false; 352 } 353 354 //===----------------------------------------------------------------------===// 355 // Sparse compiler synthesis methods (reductions). 356 //===----------------------------------------------------------------------===// 357 358 /// Maps reduction kind to name encoding. 
359 static StringRef getReductionName(Reduction kind) { 360 switch (kind) { 361 case kNoReduc: 362 break; 363 case kSum: 364 return "add"; 365 case kProduct: 366 return "mul"; 367 case kAnd: 368 return "and"; 369 case kOr: 370 return "or"; 371 case kXor: 372 return "xor"; 373 } 374 llvm_unreachable("unknown reduction kind"); 375 } 376 377 /// Maps operation to reduction. 378 static Reduction getReduction(Kind kind) { 379 switch (kind) { 380 case Kind::kAddF: 381 case Kind::kAddI: 382 case Kind::kSubF: 383 case Kind::kSubI: 384 return kSum; 385 case Kind::kMulF: 386 case Kind::kMulI: 387 return kProduct; 388 case Kind::kAndI: 389 return kAnd; 390 case Kind::kOrI: 391 return kOr; 392 case Kind::kXorI: 393 return kXor; 394 default: 395 llvm_unreachable("unexpected reduction operator"); 396 } 397 } 398 399 /// Generates an initial value for a vector reduction, following the scheme 400 /// given in Chapter 5 of "The Software Vectorization Handbook", where the 401 /// initial scalar value is correctly embedded in the vector reduction value, 402 /// and a straightforward horizontal reduction will complete the operation. 403 static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter, 404 Location loc, VectorType vtp) { 405 Value r = codegen.redVal; 406 switch (codegen.redKind) { 407 case kNoReduc: 408 break; 409 case kSum: 410 case kXor: 411 // Initialize reduction vector to: | 0 | .. | 0 | r | 412 return rewriter.create<vector::InsertElementOp>( 413 loc, r, constantZero(rewriter, loc, vtp), 414 constantIndex(rewriter, loc, 0)); 415 case kProduct: 416 // Initialize reduction vector to: | 1 | .. | 1 | r | 417 return rewriter.create<vector::InsertElementOp>( 418 loc, r, constantOne(rewriter, loc, vtp), 419 constantIndex(rewriter, loc, 0)); 420 case kAnd: 421 case kOr: 422 // Initialize reduction vector to: | r | .. 
| r | r | 423 return rewriter.create<vector::BroadcastOp>(loc, vtp, r); 424 } 425 llvm_unreachable("unknown reduction kind"); 426 } 427 428 /// Generates final value for a vector reduction. 429 static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter, 430 Location loc, VectorType vtp) { 431 StringRef name = getReductionName(codegen.redKind); 432 StringAttr kind = rewriter.getStringAttr(name); 433 return rewriter.create<vector::ReductionOp>(loc, vtp.getElementType(), kind, 434 codegen.redVal, ValueRange{}); 435 } 436 437 /// Updates scalarized reduction value. 438 static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) { 439 assert(codegen.redKind != kNoReduc); 440 codegen.redVal = merger.exp(codegen.redExp).val = reduc; 441 } 442 443 //===----------------------------------------------------------------------===// 444 // Sparse compiler synthesis methods (statements and expressions). 445 //===----------------------------------------------------------------------===// 446 447 /// Generates buffer for the output tensor. Note that all sparse kernels 448 /// assume that when all elements are written to (viz. x(i) = y(i) * z(i)), 449 /// the output buffer is already initialized to all zeroes and only nonzeroes 450 /// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)), 451 /// only nonzeroes values are used for the updates and no assumption on the 452 /// original contents of the output buffer is necessary.. 453 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter, 454 linalg::GenericOp op, MemRefType denseTp, 455 ArrayRef<Value> args) { 456 Location loc = op.getLoc(); 457 Value tensor = op.getOutputOperand(0)->get(); 458 // The output tensor simply could materialize from the buffer that will 459 // be generated for the tensor present in the outs() clause. This has 460 // the major advantage that the sparse kernel only updates the nonzero 461 // positions for the output tensor. 
462 if (isInPlace(tensor)) 463 return rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor); 464 // By default, a new buffer is allocated which is initialized to the 465 // tensor defined in the outs() clause. This is always correct but 466 // introduces a dense initialization component that may negatively 467 // impact the running complexity of the sparse kernel. If the tensor 468 // materializes into the computation, we need to preserve the zero 469 // initialization assumption of all sparse output buffers. 470 if (isMaterializing(tensor)) { 471 Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args); 472 Value zero = constantZero(rewriter, loc, denseTp.getElementType()); 473 rewriter.create<linalg::FillOp>(loc, zero, alloc); 474 return alloc; 475 } 476 Value init = rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor); 477 Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args); 478 rewriter.create<memref::CopyOp>(loc, init, alloc); 479 return alloc; 480 } 481 482 /// Local bufferization of all dense and sparse data structures. 483 /// This code enables testing the first prototype sparse compiler. 484 // TODO: replace this with a proliferated bufferization strategy 485 static void genBuffers(Merger &merger, CodeGen &codegen, 486 PatternRewriter &rewriter, linalg::GenericOp op) { 487 Location loc = op.getLoc(); 488 assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); 489 // For every tensor, find lower and upper bound on dimensions, set the 490 // same bounds on loop indices, and obtain dense or sparse buffer(s). 491 SmallVector<Value, 4> args; 492 for (OpOperand *t : op.getInputAndOutputOperands()) { 493 unsigned tensor = t->getOperandNumber(); 494 auto shape = op.getShape(t); 495 auto map = op.getTiedIndexingMap(t); 496 auto enc = getSparseTensorEncoding(t->get().getType()); 497 // Scan all dimensions of current tensor. 
498 args.clear(); 499 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 500 AffineExpr a = map.getResult(perm(enc, d)); 501 if (a.getKind() != AffineExprKind::DimId) 502 continue; // compound 503 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 504 // Handle sparse storage schemes. 505 if (merger.isDim(tensor, idx, Dim::kSparse)) { 506 auto dynShape = {ShapedType::kDynamicSize}; 507 auto ptrTp = 508 MemRefType::get(dynShape, getPointerOverheadType(rewriter, enc)); 509 auto indTp = 510 MemRefType::get(dynShape, getIndexOverheadType(rewriter, enc)); 511 Value dim = constantIndex(rewriter, loc, d); 512 // Generate sparse primitives to obtains pointer and indices. 513 codegen.pointers[tensor][idx] = 514 rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim); 515 codegen.indices[tensor][idx] = 516 rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim); 517 } 518 // Find upper bound in current dimension. 519 unsigned p = perm(enc, d); 520 Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p); 521 if (ShapedType::isDynamic(shape[p])) 522 args.push_back(up); 523 assert(codegen.highs[tensor][idx] == nullptr); 524 codegen.sizes[idx] = codegen.highs[tensor][idx] = up; 525 } 526 // Perform the required bufferization. Dense inputs materialize 527 // from the input tensors. Dense outputs need special handling. 528 // Sparse inputs use sparse primitives to obtain the values. 529 // We also accept in-place all-dense annotated "sparse" outputs. 530 Type elementType = getElementTypeOrSelf(t->get().getType()); 531 if (!enc) { 532 // Non-annotated dense tensors. 533 auto denseTp = MemRefType::get(shape, elementType); 534 if (tensor < op.getNumInputs()) 535 codegen.buffers[tensor] = 536 rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, t->get()); 537 else 538 codegen.buffers[tensor] = 539 genOutputBuffer(codegen, rewriter, op, denseTp, args); 540 } else if (t == codegen.sparseOut) { 541 // True sparse output needs a lexIdx array. 
542 Value rank = constantIndex(rewriter, loc, op.getRank(t)); 543 auto dynShape = {ShapedType::kDynamicSize}; 544 auto memTp = MemRefType::get(dynShape, rewriter.getIndexType()); 545 codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank); 546 } else { 547 // Annotated sparse tensors. 548 auto dynShape = {ShapedType::kDynamicSize}; 549 auto sparseTp = MemRefType::get(dynShape, elementType); 550 codegen.buffers[tensor] = 551 rewriter.create<ToValuesOp>(loc, sparseTp, t->get()); 552 } 553 } 554 } 555 556 /// Constructs vector type. 557 static VectorType vectorType(CodeGen &codegen, Type etp) { 558 return VectorType::get(codegen.curVecLength, etp); 559 } 560 561 /// Constructs vector type from pointer. 562 static VectorType vectorType(CodeGen &codegen, Value ptr) { 563 return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType()); 564 } 565 566 /// Constructs vector iteration mask. 567 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter, 568 Value iv, Value lo, Value hi, Value step) { 569 Location loc = iv.getLoc(); 570 VectorType mtp = vectorType(codegen, rewriter.getI1Type()); 571 // Special case if the vector length evenly divides the trip count (for 572 // example, "for i = 0, 128, 16"). A constant all-true mask is generated 573 // so that all subsequent masked memory operations are immediately folded 574 // into unconditional memory operations. 575 IntegerAttr loInt, hiInt, stepInt; 576 if (matchPattern(lo, m_Constant(&loInt)) && 577 matchPattern(hi, m_Constant(&hiInt)) && 578 matchPattern(step, m_Constant(&stepInt))) { 579 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) 580 return rewriter.create<vector::BroadcastOp>( 581 loc, mtp, constantI1(rewriter, loc, true)); 582 } 583 // Otherwise, generate a vector mask that avoids overrunning the upperbound 584 // during vector execution. 
Here we rely on subsequent loop optimizations to 585 // avoid executing the mask in all iterations, for example, by splitting the 586 // loop into an unconditional vector loop and a scalar cleanup loop. 587 auto minMap = AffineMap::get( 588 /*dimCount=*/2, /*symbolCount=*/1, 589 {rewriter.getAffineSymbolExpr(0), 590 rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)}, 591 rewriter.getContext()); 592 Value end = 593 rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step}); 594 return rewriter.create<vector::CreateMaskOp>(loc, mtp, end); 595 } 596 597 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. 598 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter, 599 Value ptr, ArrayRef<Value> args) { 600 Location loc = ptr.getLoc(); 601 VectorType vtp = vectorType(codegen, ptr); 602 Value pass = constantZero(rewriter, loc, vtp); 603 if (args.back().getType().isa<VectorType>()) { 604 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 605 Value indexVec = args.back(); 606 scalarArgs.back() = constantIndex(rewriter, loc, 0); 607 return rewriter.create<vector::GatherOp>( 608 loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); 609 } 610 return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args, 611 codegen.curVecMask, pass); 612 } 613 614 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. 
615 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter, 616 Value rhs, Value ptr, ArrayRef<Value> args) { 617 Location loc = ptr.getLoc(); 618 if (args.back().getType().isa<VectorType>()) { 619 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 620 Value indexVec = args.back(); 621 scalarArgs.back() = constantIndex(rewriter, loc, 0); 622 rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, 623 codegen.curVecMask, rhs); 624 return; 625 } 626 rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask, 627 rhs); 628 } 629 630 /// Generates a vectorized invariant. Here we rely on subsequent loop 631 /// optimizations to hoist the invariant broadcast out of the vector loop. 632 static Value genVectorInvariantValue(CodeGen &codegen, 633 PatternRewriter &rewriter, Value val) { 634 VectorType vtp = vectorType(codegen, val.getType()); 635 return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 636 } 637 638 /// Generates an affine expression. 
639 // 640 // TODO: generalize for sparse tensor subscripts 641 // 642 static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter, 643 AffineExpr a, Location loc) { 644 switch (a.getKind()) { 645 case AffineExprKind::DimId: { 646 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 647 return codegen.loops[idx]; // universal dense index 648 } 649 case AffineExprKind::Add: { 650 auto binOp = a.cast<AffineBinaryOpExpr>(); 651 return rewriter.create<arith::AddIOp>( 652 loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 653 genAffine(codegen, rewriter, binOp.getRHS(), loc)); 654 } 655 case AffineExprKind::Mul: { 656 auto binOp = a.cast<AffineBinaryOpExpr>(); 657 return rewriter.create<arith::MulIOp>( 658 loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 659 genAffine(codegen, rewriter, binOp.getRHS(), loc)); 660 } 661 case AffineExprKind::Constant: { 662 int64_t c = a.cast<AffineConstantExpr>().getValue(); 663 return constantIndex(rewriter, loc, c); 664 } 665 default: 666 llvm_unreachable("unexpected affine subscript"); 667 } 668 } 669 670 /// Generates index for load/store on sparse tensor. 671 static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) { 672 auto map = op.getTiedIndexingMap(t); 673 auto enc = getSparseTensorEncoding(t->get().getType()); 674 AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1)); 675 assert(a.getKind() == AffineExprKind::DimId); 676 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 677 return codegen.loops[idx]; 678 } 679 680 /// Generates subscript for load/store on a dense or sparse tensor. 
static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
                          linalg::GenericOp op, OpOperand *t,
                          SmallVector<Value, 4> &args) {
  // Appends the subscripts for accessing tensor `t` to `args` and returns the
  // buffer those subscripts index into (codegen.buffers[tensor]).
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    // A sparse tensor is addressed through a single position index: the one
    // recorded in codegen.pidxs for the loop index of the last dimension
    // under perm(enc, .) (presumably the innermost stored dimension).
    // Note that currently, all sparse subscripts are simple.
    // TODO: accept affine too?
    AffineExpr a = map.getResult(perm(enc, rank - 1));
    assert(a.getKind() == AffineExprKind::DimId);
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    // A dense tensor gets one subscript per dimension, each generated from
    // the (possibly compound) affine expression by genAffine.
    for (unsigned d = 0; d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
    }
  }
  return codegen.buffers[tensor];
}

/// Generates insertion code to implement dynamic tensor load.
/// Without an expanded access pattern (codegen.expValues unset), the output
/// tensor reads as a zero constant of its element type. Otherwise the value
/// is loaded from codegen.expValues at the position computed by genIndex.
static Value genInsertionLoad(CodeGen &codegen, PatternRewriter &rewriter,
                              linalg::GenericOp op, OpOperand *t) {
  Location loc = op.getLoc();
  // Direct lexicographic index order, tensor loads as zero.
  if (!codegen.expValues) {
    Type tp = getElementTypeOrSelf(t->get().getType());
    return constantZero(rewriter, loc, tp);
  }
  // Load from expanded access pattern.
  Value index = genIndex(codegen, op, t);
  return rewriter.create<memref::LoadOp>(loc, codegen.expValues, index);
}

/// Generates insertion code to implement dynamic tensor store.
/// Without an expanded access pattern, emits a LexInsertOp at the current
/// codegen.lexIdx. With expansion, it records first-time insertions in
/// expFilled/expAdded and stores `rhs` into expValues. The updated insertion
/// count is threaded through an scf.if result into codegen.expCount.
static void genInsertionStore(CodeGen &codegen, PatternRewriter &rewriter,
                              linalg::GenericOp op, OpOperand *t, Value rhs) {
  Location loc = op.getLoc();
  // Direct insertion in lexicographic index order.
  if (!codegen.expValues) {
    rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs);
    return;
  }
  // Generates insertion code along expanded access pattern.
  //   if (!expFilled[i]) then
  //     expFilled[i] = true
  //     expAdded[inserts++] = i
  //   endif
  //   values[i] = rhs
  Value index = genIndex(codegen, op, t);
  Value fval = constantI1(rewriter, loc, false);
  Value tval = constantI1(rewriter, loc, true);
  // If statement: the condition is "expFilled[index] == false".
  Value filled = rewriter.create<memref::LoadOp>(loc, codegen.expFilled, index);
  Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                              filled, fval);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, rewriter.getIndexType(),
                                              cond, /*else=*/true);
  // True branch: mark as filled, record the index, yield count + 1.
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
  rewriter.create<memref::StoreOp>(loc, tval, codegen.expFilled, index);
  rewriter.create<memref::StoreOp>(loc, index, codegen.expAdded,
                                   codegen.expCount);
  Value one = constantIndex(rewriter, loc, 1);
  Value add = rewriter.create<arith::AddIOp>(loc, codegen.expCount, one);
  rewriter.create<scf::YieldOp>(loc, add);
  // False branch: already filled, count is unchanged.
  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
  rewriter.create<scf::YieldOp>(loc, codegen.expCount);
  rewriter.setInsertionPointAfter(ifOp);
  // Value assignment (after the if, using the if's count as the new count).
  codegen.expCount = ifOp.getResult(0);
  rewriter.create<memref::StoreOp>(loc, rhs, codegen.expValues, index);
}

/// Generates a load on a dense or sparse tensor.
/// Returns the value of tensor expression `exp` in one of three ways: a
/// previously hoisted value (merger.exp(exp).val, set by genInvariants), the
/// insertion-time load for the sparse output, or an actual (vector) load
/// through genSubscript.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    // Inside a vector loop, a scalar hoisted value is passed through
    // genVectorInvariantValue (presumably a broadcast to vector form).
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Load during insertion.
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  if (t == codegen.sparseOut)
    return genInsertionLoad(codegen, rewriter, op, t);
  // Actual load.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
/// `rhs` is handled in one of three ways: folded into an active scalarized
/// reduction (codegen.redVal), inserted into the sparse output, or stored
/// (vector or scalar) into the output buffer addressed by genSubscript.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    // In vector mode, lanes disabled by the loop mask keep the previous
    // reduction value.
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    updateReduc(merger, codegen, rhs);
    return;
  }
  // Store during insertion.
  OpOperand *t = op.getOutputOperand(0);
  if (t == codegen.sparseOut) {
    genInsertionStore(codegen, rewriter, op, t, rhs);
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
814 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc, 815 Value ptr, Value s) { 816 // See https://llvm.org/docs/GetElementPtr.html for some background on 817 // the complications described below. 818 if (codegen.curVecLength > 1) { 819 // Since the index vector is used in a subsequent gather/scatter operations, 820 // which effectively defines an unsigned pointer + signed index, we must 821 // zero extend the vector to an index width. For 8-bit and 16-bit values, 822 // an 32-bit index width suffices. For 32-bit values, zero extending the 823 // elements into 64-bit loses some performance since the 32-bit indexed 824 // gather/scatter is more efficient than the 64-bit index variant (if the 825 // negative 32-bit index space is unused, the enableSIMDIndex32 flag can 826 // preserve this performance). For 64-bit values, there is no good way 827 // to state that the indices are unsigned, with creates the potential of 828 // incorrect address calculations in the unlikely case we need such 829 // extremely large offsets. 830 Type etp = ptr.getType().cast<MemRefType>().getElementType(); 831 Value vload = genVectorLoad(codegen, rewriter, ptr, {s}); 832 if (!etp.isa<IndexType>()) { 833 if (etp.getIntOrFloatBitWidth() < 32) 834 vload = rewriter.create<arith::ExtUIOp>( 835 loc, vload, vectorType(codegen, rewriter.getI32Type())); 836 else if (etp.getIntOrFloatBitWidth() < 64 && 837 !codegen.options.enableSIMDIndex32) 838 vload = rewriter.create<arith::ExtUIOp>( 839 loc, vload, vectorType(codegen, rewriter.getI64Type())); 840 } 841 return vload; 842 } 843 // For the scalar case, we simply zero extend narrower indices into 64-bit 844 // values before casting to index without a performance penalty. Here too, 845 // however, indices that already are 64-bit, in theory, cannot express the 846 // full range as explained above. 
847 Value load = rewriter.create<memref::LoadOp>(loc, ptr, s); 848 if (!load.getType().isa<IndexType>()) { 849 if (load.getType().getIntOrFloatBitWidth() < 64) 850 load = rewriter.create<arith::ExtUIOp>(loc, load, rewriter.getI64Type()); 851 load = 852 rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType()); 853 } 854 return load; 855 } 856 857 /// Generates an invariant value. 858 static Value genInvariantValue(Merger &merger, CodeGen &codegen, 859 PatternRewriter &rewriter, unsigned exp) { 860 Value val = merger.exp(exp).val; 861 if (codegen.curVecLength > 1) 862 return genVectorInvariantValue(codegen, rewriter, val); 863 return val; 864 } 865 866 /// Generates an address computation "sz * p + i". 867 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter, 868 Location loc, Value size, Value p, Value i) { 869 Value mul = rewriter.create<arith::MulIOp>(loc, size, p); 870 if (auto vtp = i.getType().dyn_cast<VectorType>()) { 871 Value inv = 872 rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType()); 873 mul = genVectorInvariantValue(codegen, rewriter, inv); 874 } 875 return rewriter.create<arith::AddIOp>(loc, mul, i); 876 } 877 878 /// Recursively generates tensor expression. 879 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 880 linalg::GenericOp op, unsigned exp) { 881 Location loc = op.getLoc(); 882 if (exp == -1u) 883 return Value(); 884 if (merger.exp(exp).kind == Kind::kTensor) 885 return genTensorLoad(merger, codegen, rewriter, op, exp); 886 if (merger.exp(exp).kind == Kind::kInvariant) 887 return genInvariantValue(merger, codegen, rewriter, exp); 888 Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0); 889 Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1); 890 return merger.buildExp(rewriter, loc, exp, v0, v1); 891 } 892 893 /// Determines if affine expression is invariant. 
894 static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a, 895 unsigned ldx, bool &atLevel) { 896 switch (a.getKind()) { 897 case AffineExprKind::DimId: { 898 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 899 if (idx == ldx) 900 atLevel = true; 901 return codegen.loops[idx] != nullptr; // no longer in play? 902 } 903 case AffineExprKind::Add: 904 case AffineExprKind::Mul: { 905 auto binOp = a.cast<AffineBinaryOpExpr>(); 906 return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) && 907 isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel); 908 } 909 default: 910 return true; 911 } 912 } 913 914 /// Hoists loop invariant tensor loads for which indices have been exhausted. 915 static void genInvariants(Merger &merger, CodeGen &codegen, 916 PatternRewriter &rewriter, linalg::GenericOp op, 917 unsigned exp, unsigned ldx, bool atStart, 918 Kind last = Kind::kTensor) { 919 if (exp == -1u) 920 return; 921 if (merger.exp(exp).kind == Kind::kTensor) { 922 // Inspect tensor indices. 923 bool atLevel = ldx == -1u; 924 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 925 auto map = op.getTiedIndexingMap(t); 926 auto enc = getSparseTensorEncoding(t->get().getType()); 927 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 928 AffineExpr a = map.getResult(perm(enc, d)); 929 if (!isInvariantAffine(codegen, a, ldx, atLevel)) 930 return; // still in play 931 } 932 // All exhausted at this level (atLevel denotes exactly at this level). 
933 if (!atLevel) 934 return; 935 OpOperand *lhs = op.getOutputOperand(0); 936 if (lhs == t) { 937 // Start or end a scalarized reduction 938 if (atStart) { 939 Value load = genTensorLoad(merger, codegen, rewriter, op, exp); 940 codegen.redKind = getReduction(last); 941 codegen.redExp = exp; 942 updateReduc(merger, codegen, load); 943 } else { 944 Value redVal = codegen.redVal; 945 updateReduc(merger, codegen, Value()); 946 codegen.redExp = -1u; 947 codegen.redKind = kNoReduc; 948 genTensorStore(merger, codegen, rewriter, op, redVal); 949 } 950 } else { 951 // Start or end loop invariant hoisting of a tensor load. 952 merger.exp(exp).val = 953 atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); 954 } 955 } else if (merger.exp(exp).kind != Kind::kInvariant) { 956 // Traverse into the binary operations. Note that we only hoist 957 // tensor loads, since subsequent MLIR/LLVM passes know how to 958 // deal with all other kinds of derived loop invariants. 959 Kind last = merger.exp(exp).kind; 960 unsigned e0 = merger.exp(exp).children.e0; 961 unsigned e1 = merger.exp(exp).children.e1; 962 genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last); 963 genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last); 964 } 965 } 966 967 /// Generates an expanded access pattern in innermost dimension. 968 static void genExpansion(Merger &merger, CodeGen &codegen, 969 PatternRewriter &rewriter, linalg::GenericOp op, 970 unsigned at, bool atStart) { 971 OpOperand *lhs = codegen.sparseOut; 972 if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 || 973 at != codegen.outerParNest) 974 return; // not needed at this level 975 // Generate start or end of an expanded access pattern. 
976 Value tensor = lhs->get(); 977 Location loc = op.getLoc(); 978 if (atStart) { 979 auto dynShape = {ShapedType::kDynamicSize}; 980 Type etp = tensor.getType().cast<ShapedType>().getElementType(); 981 Type t1 = MemRefType::get(dynShape, etp); 982 Type t2 = MemRefType::get(dynShape, rewriter.getI1Type()); 983 Type t3 = MemRefType::get(dynShape, rewriter.getIndexType()); 984 Type t4 = rewriter.getIndexType(); 985 auto res = 986 rewriter.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor); 987 assert(res.getNumResults() == 4); 988 assert(!codegen.expValues); 989 codegen.expValues = res.getResult(0); 990 codegen.expFilled = res.getResult(1); 991 codegen.expAdded = res.getResult(2); 992 codegen.expCount = res.getResult(3); 993 } else { 994 assert(codegen.expValues); 995 rewriter.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues, 996 codegen.expFilled, codegen.expAdded, 997 codegen.expCount); 998 codegen.expValues = codegen.expFilled = codegen.expAdded = 999 codegen.expCount = Value(); 1000 } 1001 } 1002 1003 /// Generates initialization code for the subsequent loop sequence at 1004 /// current index level. Returns true if the loop sequence needs to 1005 /// maintain the universal index. 1006 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 1007 linalg::GenericOp op, std::vector<unsigned> &topSort, 1008 unsigned at, llvm::BitVector &inits) { 1009 bool needsUniv = false; 1010 Location loc = op.getLoc(); 1011 unsigned idx = topSort[at]; 1012 1013 // Initialize sparse positions. 1014 for (unsigned b = 0, be = inits.size(); b < be; b++) { 1015 if (inits[b]) { 1016 unsigned tensor = merger.tensor(b); 1017 assert(idx == merger.index(b)); 1018 if (merger.isDim(b, Dim::kSparse)) { 1019 // Initialize sparse index. 
1020 unsigned pat = at; 1021 for (; pat != 0; pat--) { 1022 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1023 break; 1024 } 1025 Value ptr = codegen.pointers[tensor][idx]; 1026 Value one = constantIndex(rewriter, loc, 1); 1027 Value p0 = (pat == 0) ? constantIndex(rewriter, loc, 0) 1028 : codegen.pidxs[tensor][topSort[pat - 1]]; 1029 codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); 1030 Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one); 1031 codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1); 1032 } else { 1033 // Dense index still in play. 1034 needsUniv = true; 1035 } 1036 } 1037 } 1038 1039 // Initialize the universal dense index. 1040 codegen.loops[idx] = constantIndex(rewriter, loc, 0); 1041 return needsUniv; 1042 } 1043 1044 /// Returns vectorization strategy. Any implicit inner loop in the Linalg 1045 /// operation is a candidate. Whether it is actually converted to SIMD code 1046 /// depends on the requested strategy. 1047 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) { 1048 switch (codegen.options.vectorizationStrategy) { 1049 case SparseVectorizationStrategy::kNone: 1050 return false; 1051 case SparseVectorizationStrategy::kDenseInnerLoop: 1052 return isInner && !isSparse; 1053 case SparseVectorizationStrategy::kAnyStorageInnerLoop: 1054 return isInner; 1055 } 1056 llvm_unreachable("unexpected vectorization strategy"); 1057 } 1058 1059 /// Returns parallelization strategy. Any implicit loop in the Linalg operation 1060 /// that is marked "parallel" is a candidate. Whether it is actually converted 1061 /// to a parallel operation depends on the requested strategy. 
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  // Under every strategy other than kNone, reductions and vector loops are
  // never parallelized.
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit stride for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// This prevents effective vectorization.
/// Note: the `merger` parameter is not used by this function.
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        AffineExpr a = map.getResult(d);
        // Report non-unit stride if innermost index appears at an outer
        // dimension (true non-unit stride) or if the innermost index appears
        // in a compound subscript in the innermost dimension. Even if the
        // latter is unit stride, it does not play well with scatter/gather.
        // TODO: accept unit stride affine innermost like a[i,j+k+1]?
        if (a.isFunctionOfDim(idx) &&
            ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId)))
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
///
/// Vectorization and parallelization are both disabled when the kernel has a
/// sparse output. A parallel loop becomes an scf.parallel without iteration
/// arguments. Otherwise an scf.for is emitted whose iteration arguments are,
/// in order, the reduction value (if any, converted to vector form for a
/// vector loop) and the expansion count (if any). The induction variable is
/// recorded as the sparse position (codegen.pidxs) or the dense index
/// (codegen.loops). The insertion point is left at the start of the body,
/// and the loop op is returned.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = isReductionIterator(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = !codegen.sparseOut &&
                  isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      !codegen.sparseOut &&
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length. It stays in effect until endLoop resets it to 1.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment. Sparse loops run over positions [pidx, high);
  // dense loops run over [universal index, size).
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = constantIndex(rewriter, loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential or vector loop.
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    // In a vector loop, bring reduction into SIMD form, if not already.
    if (isVector && !codegen.redVal.getType().isa<VectorType>()) {
      VectorType vtp = vectorType(codegen, codegen.redVal.getType());
      Value vred = genVectorReducInit(codegen, rewriter, loc, vtp);
      updateReduc(merger, codegen, vred);
    }
    operands.push_back(codegen.redVal);
  }
  if (codegen.expValues)
    operands.push_back(codegen.expCount);
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  // Inside the body, the carried values are the region iteration arguments
  // (reduction first, expansion count last).
  if (codegen.redVal)
    updateReduc(merger, codegen, forOp.getRegionIterArgs().front());
  if (codegen.expValues)
    codegen.expCount = forOp.getRegionIterArgs().back();
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emit a while-loop for co-iteration over multiple indices.
///
/// The loop-carried values are, in order: one position per participating
/// sparse tensor, the reduction value (if any), the expansion count (if any),
/// and the universal dense index (if `needsUniv`). The "before" region
/// continues while every sparse position is below its upper bound in
/// codegen.highs. codegen is updated to refer to the "after" block arguments,
/// and the insertion point is left at the start of the "after" region.
/// NOTE(review): `cond` stays null if no sparse bit is set in `indices`.
/// This appears to rely on callers always passing at least one sparse index
/// (TODO confirm).
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (codegen.redVal) {
    types.push_back(codegen.redVal.getType());
    operands.push_back(codegen.redVal);
  }
  if (codegen.expValues) {
    types.push_back(indexType);
    operands.push_back(codegen.expCount);
  }
  if (needsUniv) {
    types.push_back(indexType);
    operands.push_back(codegen.loops[idx]);
  }
  assert(types.size() == operands.size());
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);

  SmallVector<Location> locs(types.size(), loc);
  Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, locs);
  Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, types, locs);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                 op1, op2);
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (codegen.redVal)
    updateReduc(merger, codegen, after->getArgument(o++));
  if (codegen.expValues)
    codegen.expCount = after->getArgument(o++);
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
  return whileOp;
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
///
/// Sparse indices are loaded from the indices array at the current positions.
/// When the universal index is not maintained (!needsUniv), it is set to the
/// unsigned minimum of those sparse indices. Dense positions are linearized
/// as "size * enclosing position + index" via genAddress. For a sparse output
/// without expansion, the current index is also stored into codegen.lexIdx
/// at slot `at`.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      // Find the closest enclosing loop level with a position for this tensor.
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? constantIndex(rewriter, loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }

  // Move the insertion indices in lexicographic index order. During access
  // pattern expansion, we can skip setting the innermost dimension.
  if (codegen.sparseOut && !codegen.expValues) {
    Value pos = constantIndex(rewriter, loc, at);
    rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
                                     pos);
  }
}

/// Generates the induction structure for a while-loop.
///
/// First closes every enclosing scf.if (created by genIf) by yielding the
/// current reduction/expansion values from its else branch and continuing
/// after it. Then, at the end of the "after" region, it yields:
///  - each sparse position, advanced by 1 only where its index equals the
///    universal index,
///  - the reduction value and expansion count (if any),
///  - the universal index + 1 (if `needsUniv`).
/// codegen is rebound to the while-loop results, and the insertion point is
/// set after the loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction,
                              scf::WhileOp whileOp) {
  Location loc = op.getLoc();
  // Finalize each else branch of all if statements.
  if (codegen.redVal || codegen.expValues) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               rewriter.getInsertionBlock()->getParentOp())) {
      unsigned y = 0;
      SmallVector<Value, 4> yields;
      if (codegen.redVal) {
        yields.push_back(codegen.redVal);
        updateReduc(merger, codegen, ifOp.getResult(y++));
      }
      if (codegen.expValues) {
        yields.push_back(codegen.expCount);
        codegen.expCount = ifOp->getResult(y++);
      }
      assert(y == yields.size());
      rewriter.create<scf::YieldOp>(loc, yields);
      rewriter.setInsertionPointAfter(ifOp);
    }
  }
  rewriter.setInsertionPointToEnd(&whileOp.getAfter().front());
  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = constantIndex(rewriter, loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 op1, op2);
      Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
    }
  }
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, whileOp->getResult(o++));
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = whileOp->getResult(o++);
  }
  if (needsUniv) {
    operands.push_back(
        rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = whileOp->getResult(o++);
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(whileOp);
}

/// Generates the induction structure for a for-loop.
1390 static void genForInduction(Merger &merger, CodeGen &codegen, 1391 PatternRewriter &rewriter, linalg::GenericOp op, 1392 Operation *loop) { 1393 Location loc = op.getLoc(); 1394 unsigned o = 0; 1395 SmallVector<Value, 4> operands; 1396 if (codegen.redVal) { 1397 operands.push_back(codegen.redVal); 1398 updateReduc(merger, codegen, loop->getResult(o++)); 1399 } 1400 if (codegen.expValues) { 1401 operands.push_back(codegen.expCount); 1402 codegen.expCount = loop->getResult(o++); 1403 } 1404 assert(o == operands.size()); 1405 if (o > 0) 1406 rewriter.create<scf::YieldOp>(loc, operands); 1407 rewriter.setInsertionPointAfter(loop); 1408 } 1409 1410 /// Generates a single if-statement within a while-loop. 1411 static scf::IfOp genIf(Merger &merger, CodeGen &codegen, 1412 PatternRewriter &rewriter, linalg::GenericOp op, 1413 unsigned idx, llvm::BitVector &conditions) { 1414 Location loc = op.getLoc(); 1415 SmallVector<Type, 4> types; 1416 Value cond; 1417 for (unsigned b = 0, be = conditions.size(); b < be; b++) { 1418 if (conditions[b]) { 1419 unsigned tensor = merger.tensor(b); 1420 assert(idx == merger.index(b)); 1421 Value clause; 1422 if (merger.isDim(b, Dim::kSparse)) { 1423 Value op1 = codegen.idxs[tensor][idx]; 1424 Value op2 = codegen.loops[idx]; 1425 clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 1426 op1, op2); 1427 } else { 1428 clause = constantI1(rewriter, loc, true); 1429 } 1430 cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause; 1431 } 1432 } 1433 if (codegen.redVal) 1434 types.push_back(codegen.redVal.getType()); 1435 if (codegen.expValues) 1436 types.push_back(rewriter.getIndexType()); 1437 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true); 1438 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); 1439 return ifOp; 1440 } 1441 1442 /// Generates end of true branch of if-statement within a while-loop. 
static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                  linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
                  Value redInput, Value cntInput) {
  // Yields the current reduction value / expansion count from the then-branch.
  // It then resets them to the values they had before the branch
  // (redInput/cntInput), so the else-branch starts from the same state, and
  // moves the insertion point to the start of the else region.
  // Note: the `loop` parameter is not used here.
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, redInput);
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = cntInput;
  }
  if (!operands.empty())
    rewriter.create<scf::YieldOp>(op.getLoc(), operands);
  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
/// This emits, in order: invariant hoisting / reduction start, access
/// pattern expansion, and position/universal-index initialization.
static bool startLoopSeq(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         std::vector<unsigned> &topSort, unsigned exp,
                         unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  assert(!codegen.loops[idx]);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
  // Emit access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  bool needsUniv =
      genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  if (needsUniv) {
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}

/// Starts a single loop in current sequence.
/// Emits the loop control for lattice point `li`, using its simplified
/// condition bits, then the loop-local values, using its full bits.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in current sequence. Returns the new value for
/// needsUniv: unchanged after a while-loop, false after a for-loop.
/// Also resets the current vector length back to scalar (1).
static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp);
    return needsUniv;
  }
  // End a for-loop.
  genForInduction(merger, codegen, rewriter, op, loop);
  return false;
}

/// Ends a loop sequence at given level.
/// Clears the universal index for `idx` and reduces a vector-form reduction
/// back to scalar form. It then ends invariant hoisting or the scalarized
/// reduction, and finalizes access pattern expansion.
static void endLoopSeq(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned exp, unsigned at, unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  codegen.loops[idx] = Value();
  // Bring a pending reduction back from SIMD form when sequence ends.
  if (codegen.redVal)
    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
      updateReduc(merger, codegen,
                  genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
  // Finalize access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/false);
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iterations spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  // `ldx` is the enclosing loop index, or -1u at the outermost level.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for coiteration.
    // redInput/cntInput hold the loop-entry values that each else-branch
    // restarts from (see endIf).
    Value redInput = codegen.redVal;
    Value cntInput = codegen.expCount;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
}

/// Converts the result computed by the sparse kernel into the required form.
/// Replaces `op` with either a sparse LoadOp on the output tensor or a
/// ToTensorOp over the last entry of codegen.buffers. The last entry is
/// presumably the output's buffer, since the output is the last operand
/// (TODO confirm).
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format. The flag tells LoadOp whether this
    // output received insertions (it is the kernel's sparse output).
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
                                        codegen.sparseOut == lhs);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = codegen.buffers.back(); // value array
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
  }
}

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
1622 //===----------------------------------------------------------------------===// 1623 1624 namespace { 1625 1626 /// Sparse rewriting rule for generic Lingalg operation. 1627 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> { 1628 public: 1629 GenericOpSparsifier(MLIRContext *context, SparsificationOptions o) 1630 : OpRewritePattern<linalg::GenericOp>(context), options(o) {} 1631 1632 LogicalResult matchAndRewrite(linalg::GenericOp op, 1633 PatternRewriter &rewriter) const override { 1634 // Detects sparse annotations and translate the per-dimension sparsity 1635 // information for all tensors to loop indices in the kernel. 1636 assert(op.getNumOutputs() == 1); 1637 unsigned numTensors = op.getNumInputsAndOutputs(); 1638 unsigned numLoops = op.iterator_types().getValue().size(); 1639 Merger merger(numTensors, numLoops); 1640 if (!findSparseAnnotations(merger, op)) 1641 return failure(); 1642 1643 // Computes a topologically sorted iteration graph to ensure 1644 // tensors are visited in natural index order. Fails on cycles. 1645 // This assumes that higher-level passes have already put the 1646 // tensors in each tensor expression in a feasible order. 1647 std::vector<unsigned> topSort; 1648 if (!computeIterationGraph(merger, op, topSort, 1649 SortMask::kIncludeUndef | 1650 SortMask::kIncludeDense) && 1651 !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) && 1652 !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) && 1653 !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly)) 1654 return failure(); 1655 1656 // Builds the tensor expression for the Linalg operation in SSA form. 1657 Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op); 1658 if (!optExp.hasValue()) 1659 return failure(); 1660 unsigned exp = optExp.getValue(); 1661 1662 // Rejects an inadmissable tensor expression. 
1663 OpOperand *sparseOut = nullptr; 1664 unsigned outerParNest = 0; 1665 if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut, 1666 outerParNest)) 1667 return failure(); 1668 1669 // Recursively generates code. 1670 merger.setHasSparseOut(sparseOut != nullptr); 1671 CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest); 1672 genBuffers(merger, codegen, rewriter, op); 1673 genStmt(merger, codegen, rewriter, op, topSort, exp, 0); 1674 genResult(merger, codegen, rewriter, op); 1675 return success(); 1676 } 1677 1678 private: 1679 /// Options to control sparse code generation. 1680 SparsificationOptions options; 1681 }; 1682 1683 } // namespace 1684 1685 /// Populates the given patterns list with rewriting rules required for 1686 /// the sparsification of linear algebra operations. 1687 void mlir::populateSparsificationPatterns( 1688 RewritePatternSet &patterns, const SparsificationOptions &options) { 1689 patterns.add<GenericOpSparsifier>(patterns.getContext(), options); 1690 } 1691