1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements converting sparse tensor types to actual sparse code. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CodegenUtils.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 18 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 19 #include "mlir/Dialect/Linalg/IR/Linalg.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/MemRef/IR/MemRef.h" 22 #include "mlir/Dialect/SCF/SCF.h" 23 #include "mlir/Dialect/SCF/Transforms.h" 24 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 25 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 26 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 27 #include "mlir/Dialect/Vector/IR/VectorOps.h" 28 #include "mlir/IR/Matchers.h" 29 #include "mlir/IR/TensorEncoding.h" 30 #include "llvm/ADT/SmallBitVector.h" 31 32 using namespace mlir; 33 using namespace mlir::sparse_tensor; 34 35 //===----------------------------------------------------------------------===// 36 // Declarations of data structures. 37 //===----------------------------------------------------------------------===// 38 39 namespace { 40 41 // Iteration graph sorting. 42 enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 }; 43 44 // Reduction kinds. 45 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor }; 46 47 // Code generation. 48 struct CodeGen { 49 CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops, 50 OpOperand *op, unsigned nest) 51 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 52 pointers(numTensors, std::vector<Value>(numLoops)), 53 indices(numTensors, std::vector<Value>(numLoops)), 54 highs(numTensors, std::vector<Value>(numLoops)), 55 pidxs(numTensors, std::vector<Value>(numLoops)), 56 idxs(numTensors, std::vector<Value>(numLoops)), redVal(), sparseOut(op), 57 outerParNest(nest), lexIdx(), expValues(), expFilled(), expAdded(), 58 expCount(), curVecMask() {} 59 /// Sparsification options. 60 SparsificationOptions options; 61 /// Universal dense indices and upper bounds (by index). The loops array 62 /// is updated with the value of the universal dense index in the current 63 /// loop. The sizes array is set once with the inferred dimension sizes. 64 std::vector<Value> loops; 65 std::vector<Value> sizes; 66 /// Buffers for storing dense and sparse numerical values (by tensor). 67 /// This array is set once during bufferization of all tensors. 68 std::vector<Value> buffers; 69 /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 70 /// This array is set once during bufferization of all sparse tensors. 71 std::vector<std::vector<Value>> pointers; 72 std::vector<std::vector<Value>> indices; 73 /// Sparse iteration information (by tensor and index). These arrays 74 /// are updated to remain current within the current loop. 
75 std::vector<std::vector<Value>> highs; 76 std::vector<std::vector<Value>> pidxs; 77 std::vector<std::vector<Value>> idxs; 78 /// Current reduction, updated during code generation. When indices of a 79 /// reduction are exhausted, all inner loops can use a scalarized reduction. 80 unsigned redExp = -1u; 81 Value redVal; 82 Reduction redKind = kNoReduc; 83 // Sparse tensor as output. Implemented either through direct injective 84 // insertion in lexicographic index order (where indices are updated 85 // in the temporary array `lexIdx`) or through access pattern expansion 86 // in the innermost loop nest (`expValues` through `expCount`). 87 OpOperand *sparseOut; 88 unsigned outerParNest; 89 Value lexIdx; 90 Value expValues; 91 Value expFilled; 92 Value expAdded; 93 Value expCount; 94 // Current vector length and mask. 95 unsigned curVecLength = 1; 96 Value curVecMask; 97 }; 98 99 } // namespace 100 101 //===----------------------------------------------------------------------===// 102 // Sparse compiler analysis methods. 103 //===----------------------------------------------------------------------===// 104 105 /// Helper method to apply dimension ordering permutation. 106 static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) { 107 if (enc) { 108 auto order = enc.getDimOrdering(); 109 if (order) { 110 assert(order.isPermutation()); 111 return order.getDimPosition(d); 112 } 113 } 114 return d; 115 } 116 117 /// Helper method to translate dim level type to internal representation. 118 static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) { 119 if (enc) { 120 SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 121 if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 122 return Dim::kSparse; 123 if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 124 return Dim::kSingle; 125 } 126 return Dim::kDense; 127 } 128 129 /// Helper method to inspect affine expressions. Rejects cases where the 130 /// same index is used more than once. Also rejects affine expressions 131 /// that are not a direct index for annotated tensors. 132 // TODO: accept more affine cases for sparse tensors 133 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim, 134 bool isDense) { 135 switch (a.getKind()) { 136 case AffineExprKind::DimId: { 137 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 138 if (!merger.isDim(tensor, idx, Dim::kUndef)) 139 return false; // used more than once 140 merger.setDim(tensor, idx, dim); 141 return true; 142 } 143 case AffineExprKind::Add: 144 case AffineExprKind::Mul: { 145 if (!isDense) 146 return false; 147 auto binOp = a.cast<AffineBinaryOpExpr>(); 148 return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) && 149 findAffine(merger, tensor, binOp.getRHS(), dim, isDense); 150 } 151 case AffineExprKind::Constant: 152 return isDense; 153 default: 154 return false; 155 } 156 } 157 158 /// Helper method to inspect sparse encodings in the tensor types. 159 /// Fills the per-dimension sparsity information for all tensors. 160 /// Returns true if the sparse annotations and affine subscript 161 /// expressions of all tensors are admissable. Returns false if 162 /// no annotations are found or inadmissable constructs occur. 
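/// For example, a matrix operand B(k,j) annotated with a CSR-style encoding,
/// roughly
///   #CSR = #sparse_tensor.encoding<{
///     dimLevelType = [ "dense", "compressed" ]
///   }>
/// is recorded as Dim::kDense for k and Dim::kSparse for j, whereas a compound
/// subscript such as B(k+1,j) would be rejected for an annotated tensor
/// (an illustrative sketch; in practice the encoding is attached to the tensor
/// type in the input IR).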
163 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 164 bool annotated = false; 165 for (OpOperand *t : op.getInputAndOutputOperands()) { 166 auto map = op.getTiedIndexingMap(t); 167 auto enc = getSparseTensorEncoding(t->get().getType()); 168 if (enc) 169 annotated = true; 170 assert(map.getNumResults() == op.getRank(t)); 171 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 172 unsigned tensor = t->getOperandNumber(); 173 AffineExpr a = map.getResult(perm(enc, d)); 174 if (!findAffine(merger, tensor, a, toDim(enc, d), !enc)) 175 return false; // inadmissable affine expression 176 } 177 } 178 return annotated; 179 } 180 181 /// A DFS helper to compute a topological sort. Note that recursion is 182 /// bounded by the number of implicit loops, which is always small. 183 /// Returns false when a cycle is detected. 184 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 185 std::vector<unsigned> &topSort, 186 std::vector<std::vector<bool>> &adjM) { 187 if (visit[i] != 0) 188 return visit[i] != 1; // 1 denotes cycle! 189 visit[i] = 1; 190 for (unsigned j = 0, e = visit.size(); j < e; j++) 191 if (adjM[i][j]) 192 if (!topSortDFS(j, visit, topSort, adjM)) 193 return false; 194 visit[i] = 2; 195 topSort.push_back(i); 196 return true; 197 } 198 199 /// Helper method to add all constraints from the indices in one affine 200 /// expression before all indices in the other affine expression. For 201 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3. 202 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM, 203 AffineExpr a, AffineExpr b, unsigned fidx) { 204 switch (a.getKind()) { 205 case AffineExprKind::DimId: { 206 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 207 if (b) 208 addAffineOrderings(adjM, b, AffineExpr(), idx); 209 else 210 adjM[fidx][idx] = true; 211 break; 212 } 213 case AffineExprKind::Add: 214 case AffineExprKind::Mul: { 215 auto binOp = a.cast<AffineBinaryOpExpr>(); 216 addAffineOrderings(adjM, binOp.getLHS(), b, fidx); 217 addAffineOrderings(adjM, binOp.getRHS(), b, fidx); 218 break; 219 } 220 default: 221 break; 222 } 223 } 224 225 /// Computes a topologically sorted iteration graph for the linalg operation. 226 /// Ensures all tensors are visited in natural index order. This is essential 227 /// for sparse storage formats since these only support access along fixed 228 /// dimensions. Even for dense storage formats, however, the natural index 229 /// order yields innermost unit-stride access with better spatial locality. 230 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, 231 std::vector<unsigned> &topSort, 232 unsigned mask) { 233 // Set up an n x n from/to adjacency matrix of the iteration graph 234 // for the implicit loop indices i_0 .. i_n-1. 235 unsigned n = op.getNumLoops(); 236 std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false)); 237 238 // Iterate over the indexing maps of every tensor in the tensor expression. 239 for (OpOperand *t : op.getInputAndOutputOperands()) { 240 auto map = op.getTiedIndexingMap(t); 241 auto enc = getSparseTensorEncoding(t->get().getType()); 242 assert(map.getNumDims() == n); 243 // Skip dense tensor constraints when not requested. 244 if (!(mask & SortMask::kIncludeDense) && !enc) 245 continue; 246 // Each tensor expression and optional dimension ordering (row-major 247 // by default) puts an ordering constraint on the loop indices. 
    // For example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if tensor has an in-place annotation.
static bool isInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(),
              bufferization::BufferizableOpInterface::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Returns true if tensor materializes uninitialized into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<InitOp>();
}

/// Returns true when the tensor expression is admissable for codegen.
/// Since all sparse input tensors are admissable, we just need to check
/// whether the out tensor in the tensor expression codegen is admissable.
/// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
/// nesting depth when a "truly dynamic" sparse tensor output occurs.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort, unsigned exp,
                                  OpOperand **sparseOut,
                                  unsigned &outerParNest) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissable since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissable since insertions cannot occur.
  bool allDense = true;
  auto iteratorTypes = op.iterator_types().getValue();
  unsigned numLoops = iteratorTypes.size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissable without special codegen, provided
  // the tensor's underlying sparse storage scheme can be modified in place.
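  // For example, a kernel x(i) = x(i) * 2.0 with an in-place sparse x only
  // rescales the stored nonzero values and never changes which entries are
  // stored, so it can be generated without any insertion support (an
  // illustrative instance of the "simply dynamic" case).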
329 if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get())) 330 return true; 331 // Accept "truly dynamic" if the output tensor materializes uninitialized 332 // into the computation and insertions occur in lexicographic index order. 333 if (isMaterializing(lhs->get())) { 334 unsigned nest = 0; 335 for (unsigned i = 0; i < numLoops; i++) { 336 if (isReductionIterator(iteratorTypes[topSort[i]])) 337 break; // terminate at first reduction 338 nest++; 339 } 340 // Determine admissable dynamic insertion situations: 341 // (1) fully injective, since there are no reductions, 342 // (2) admissable 1-d expansion in innermost dimension. 343 if (nest >= op.getRank(lhs) - 1) { 344 *sparseOut = lhs; 345 outerParNest = nest; 346 return true; 347 } 348 } 349 return false; 350 } 351 352 //===----------------------------------------------------------------------===// 353 // Sparse compiler synthesis methods (reductions). 354 //===----------------------------------------------------------------------===// 355 356 /// Maps reduction kind to vector::CombiningKind. 357 static vector::CombiningKind getCombiningKind(Reduction kind) { 358 switch (kind) { 359 case kNoReduc: 360 break; 361 case kSum: 362 return vector::CombiningKind::ADD; 363 case kProduct: 364 return vector::CombiningKind::MUL; 365 case kAnd: 366 return vector::CombiningKind::AND; 367 case kOr: 368 return vector::CombiningKind::OR; 369 case kXor: 370 return vector::CombiningKind::XOR; 371 } 372 llvm_unreachable("unknown reduction kind"); 373 } 374 375 /// Maps operation to reduction. 376 static Reduction getReduction(Kind kind) { 377 switch (kind) { 378 case Kind::kAddF: 379 case Kind::kAddI: 380 case Kind::kSubF: 381 case Kind::kSubI: 382 return kSum; 383 case Kind::kMulF: 384 case Kind::kMulI: 385 return kProduct; 386 case Kind::kAndI: 387 return kAnd; 388 case Kind::kOrI: 389 return kOr; 390 case Kind::kXorI: 391 return kXor; 392 default: 393 llvm_unreachable("unexpected reduction operator"); 394 } 395 } 396 397 /// Generates an initial value for a vector reduction, following the scheme 398 /// given in Chapter 5 of "The Software Vectorization Handbook", where the 399 /// initial scalar value is correctly embedded in the vector reduction value, 400 /// and a straightforward horizontal reduction will complete the operation. 401 static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter, 402 Location loc, VectorType vtp) { 403 Value r = codegen.redVal; 404 switch (codegen.redKind) { 405 case kNoReduc: 406 break; 407 case kSum: 408 case kXor: 409 // Initialize reduction vector to: | 0 | .. | 0 | r | 410 return rewriter.create<vector::InsertElementOp>( 411 loc, r, constantZero(rewriter, loc, vtp), 412 constantIndex(rewriter, loc, 0)); 413 case kProduct: 414 // Initialize reduction vector to: | 1 | .. | 1 | r | 415 return rewriter.create<vector::InsertElementOp>( 416 loc, r, constantOne(rewriter, loc, vtp), 417 constantIndex(rewriter, loc, 0)); 418 case kAnd: 419 case kOr: 420 // Initialize reduction vector to: | r | .. | r | r | 421 return rewriter.create<vector::BroadcastOp>(loc, vtp, r); 422 } 423 llvm_unreachable("unknown reduction kind"); 424 } 425 426 /// Generates final value for a vector reduction. 427 static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter, 428 Location loc, VectorType vtp) { 429 vector::CombiningKind kind = getCombiningKind(codegen.redKind); 430 return rewriter.create<vector::ReductionOp>(loc, kind, codegen.redVal); 431 } 432 433 /// Updates scalarized reduction value. 
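/// For a reduction such as x += a(i) * b(i), for example, the value of x is
/// loaded once before the loop sequence, every iteration then updates the
/// running value in a scalar (or vector) register through this helper, and
/// the result is only stored back to memory once the reduction indices are
/// exhausted (an illustrative sketch of the scheme used below).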
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (isInPlace(tensor))
    return rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel. If the tensor
  // materializes into the computation, we need to preserve the zero
  // initialization assumption of all sparse output buffers.
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  if (isMaterializing(tensor)) {
    Value zero = constantZero(rewriter, loc, denseTp.getElementType());
    rewriter.create<linalg::FillOp>(loc, zero, alloc);
  } else {
    Value init =
        rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
    rewriter.create<memref::CopyOp>(loc, init, alloc);
  }
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
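      // For a CSR-style matrix, for instance, the compressed dimension d = 1
      // yields two 1-d memrefs through the sparse_tensor dialect, roughly
      //   %ptrs = sparse_tensor.pointers %B, %c1 : tensor<...> to memref<?xi32>
      //   %inds = sparse_tensor.indices  %B, %c1 : tensor<...> to memref<?xi32>
      // (an illustrative sketch; the actual overhead types depend on the
      // encoding attributes queried above).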
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp =
            MemRefType::get(dynShape, getPointerOverheadType(rewriter, enc));
        auto indTp =
            MemRefType::get(dynShape, getIndexOverheadType(rewriter, enc));
        Value dim = constantIndex(rewriter, loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (ShapedType::isDynamic(shape[p]))
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else if (t == codegen.sparseOut) {
      // True sparse output needs a lexIdx array.
      Value rank = constantIndex(rewriter, loc, op.getRank(t));
      auto dynShape = {ShapedType::kDynamicSize};
      auto memTp = MemRefType::get(dynShape, rewriter.getIndexType());
      codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank);
    } else {
      // Annotated sparse tensors.
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
571 IntegerAttr loInt, hiInt, stepInt; 572 if (matchPattern(lo, m_Constant(&loInt)) && 573 matchPattern(hi, m_Constant(&hiInt)) && 574 matchPattern(step, m_Constant(&stepInt))) { 575 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) 576 return rewriter.create<vector::BroadcastOp>( 577 loc, mtp, constantI1(rewriter, loc, true)); 578 } 579 // Otherwise, generate a vector mask that avoids overrunning the upperbound 580 // during vector execution. Here we rely on subsequent loop optimizations to 581 // avoid executing the mask in all iterations, for example, by splitting the 582 // loop into an unconditional vector loop and a scalar cleanup loop. 583 auto minMap = AffineMap::get( 584 /*dimCount=*/2, /*symbolCount=*/1, 585 {rewriter.getAffineSymbolExpr(0), 586 rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)}, 587 rewriter.getContext()); 588 Value end = 589 rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step}); 590 return rewriter.create<vector::CreateMaskOp>(loc, mtp, end); 591 } 592 593 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. 594 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter, 595 Value ptr, ArrayRef<Value> args) { 596 Location loc = ptr.getLoc(); 597 VectorType vtp = vectorType(codegen, ptr); 598 Value pass = constantZero(rewriter, loc, vtp); 599 if (args.back().getType().isa<VectorType>()) { 600 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 601 Value indexVec = args.back(); 602 scalarArgs.back() = constantIndex(rewriter, loc, 0); 603 return rewriter.create<vector::GatherOp>( 604 loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); 605 } 606 return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args, 607 codegen.curVecMask, pass); 608 } 609 610 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. 611 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter, 612 Value rhs, Value ptr, ArrayRef<Value> args) { 613 Location loc = ptr.getLoc(); 614 if (args.back().getType().isa<VectorType>()) { 615 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 616 Value indexVec = args.back(); 617 scalarArgs.back() = constantIndex(rewriter, loc, 0); 618 rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, 619 codegen.curVecMask, rhs); 620 return; 621 } 622 rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask, 623 rhs); 624 } 625 626 /// Generates a vectorized invariant. Here we rely on subsequent loop 627 /// optimizations to hoist the invariant broadcast out of the vector loop. 628 static Value genVectorInvariantValue(CodeGen &codegen, 629 PatternRewriter &rewriter, Value val) { 630 VectorType vtp = vectorType(codegen, val.getType()); 631 return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 632 } 633 634 /// Generates an affine expression. 
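/// For example, a dense subscript expression d0 * 2 + d1 is emitted as
///   %c2 = arith.constant 2 : index
///   %m  = arith.muli %i, %c2 : index
///   %a  = arith.addi %m, %j : index
/// where %i and %j are the current loop indices (an illustrative sketch;
/// the SSA value names are made up).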
635 // 636 // TODO: generalize for sparse tensor subscripts 637 // 638 static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter, 639 AffineExpr a, Location loc) { 640 switch (a.getKind()) { 641 case AffineExprKind::DimId: { 642 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 643 return codegen.loops[idx]; // universal dense index 644 } 645 case AffineExprKind::Add: { 646 auto binOp = a.cast<AffineBinaryOpExpr>(); 647 return rewriter.create<arith::AddIOp>( 648 loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 649 genAffine(codegen, rewriter, binOp.getRHS(), loc)); 650 } 651 case AffineExprKind::Mul: { 652 auto binOp = a.cast<AffineBinaryOpExpr>(); 653 return rewriter.create<arith::MulIOp>( 654 loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 655 genAffine(codegen, rewriter, binOp.getRHS(), loc)); 656 } 657 case AffineExprKind::Constant: { 658 int64_t c = a.cast<AffineConstantExpr>().getValue(); 659 return constantIndex(rewriter, loc, c); 660 } 661 default: 662 llvm_unreachable("unexpected affine subscript"); 663 } 664 } 665 666 /// Generates index for load/store on sparse tensor. 667 static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) { 668 auto map = op.getTiedIndexingMap(t); 669 auto enc = getSparseTensorEncoding(t->get().getType()); 670 AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1)); 671 assert(a.getKind() == AffineExprKind::DimId); 672 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 673 return codegen.loops[idx]; 674 } 675 676 /// Generates subscript for load/store on a dense or sparse tensor. 677 static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter, 678 linalg::GenericOp op, OpOperand *t, 679 SmallVector<Value, 4> &args) { 680 unsigned tensor = t->getOperandNumber(); 681 auto map = op.getTiedIndexingMap(t); 682 auto enc = getSparseTensorEncoding(t->get().getType()); 683 unsigned rank = map.getNumResults(); 684 if (enc) { 685 // Note that currently, all sparse subscripts are simple. 686 // TODO: accept affine too? 687 AffineExpr a = map.getResult(perm(enc, rank - 1)); 688 assert(a.getKind() == AffineExprKind::DimId); 689 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 690 assert(codegen.pidxs[tensor][idx] != nullptr); 691 args.push_back(codegen.pidxs[tensor][idx]); // position index 692 } else { 693 for (unsigned d = 0; d < rank; d++) { 694 AffineExpr a = map.getResult(perm(enc, d)); 695 args.push_back(genAffine(codegen, rewriter, a, op.getLoc())); 696 } 697 } 698 return codegen.buffers[tensor]; 699 } 700 701 /// Generates insertion code to implement dynamic tensor load. 702 static Value genInsertionLoad(CodeGen &codegen, PatternRewriter &rewriter, 703 linalg::GenericOp op, OpOperand *t) { 704 Location loc = op.getLoc(); 705 // Direct lexicographic index order, tensor loads as zero. 706 if (!codegen.expValues) { 707 Type tp = getElementTypeOrSelf(t->get().getType()); 708 return constantZero(rewriter, loc, tp); 709 } 710 // Load from expanded access pattern. 711 Value index = genIndex(codegen, op, t); 712 return rewriter.create<memref::LoadOp>(loc, codegen.expValues, index); 713 } 714 715 /// Generates insertion code to implement dynamic tensor store. 716 static void genInsertionStore(CodeGen &codegen, PatternRewriter &rewriter, 717 linalg::GenericOp op, OpOperand *t, Value rhs) { 718 Location loc = op.getLoc(); 719 // Direct insertion in lexicographic index order. 
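  // For a 2-d output, for instance, each store then amounts to, roughly,
  //   sparse_tensor.lex_insert %t, %lexIdx, %rhs : ...
  // with the memref %lexIdx holding the current (i, j) pair (an illustrative
  // sketch of the direct insertion path taken below).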
720 if (!codegen.expValues) { 721 rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs); 722 return; 723 } 724 // Generates insertion code along expanded access pattern. 725 // if (!expFilled[i]) then 726 // expFilled[i] = true 727 // expAdded[inserts++] = i 728 // endif 729 // values[i] = rhs 730 Value index = genIndex(codegen, op, t); 731 Value fval = constantI1(rewriter, loc, false); 732 Value tval = constantI1(rewriter, loc, true); 733 // If statement. 734 Value filled = rewriter.create<memref::LoadOp>(loc, codegen.expFilled, index); 735 Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 736 filled, fval); 737 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, rewriter.getIndexType(), 738 cond, /*else=*/true); 739 // True branch. 740 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); 741 rewriter.create<memref::StoreOp>(loc, tval, codegen.expFilled, index); 742 rewriter.create<memref::StoreOp>(loc, index, codegen.expAdded, 743 codegen.expCount); 744 Value one = constantIndex(rewriter, loc, 1); 745 Value add = rewriter.create<arith::AddIOp>(loc, codegen.expCount, one); 746 rewriter.create<scf::YieldOp>(loc, add); 747 // False branch. 748 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); 749 rewriter.create<scf::YieldOp>(loc, codegen.expCount); 750 rewriter.setInsertionPointAfter(ifOp); 751 // Value assignment. 752 codegen.expCount = ifOp.getResult(0); 753 rewriter.create<memref::StoreOp>(loc, rhs, codegen.expValues, index); 754 } 755 756 /// Generates a load on a dense or sparse tensor. 757 static Value genTensorLoad(Merger &merger, CodeGen &codegen, 758 PatternRewriter &rewriter, linalg::GenericOp op, 759 unsigned exp) { 760 // Test if the load was hoisted to a higher loop nest. 761 Value val = merger.exp(exp).val; 762 if (val) { 763 if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>()) 764 return genVectorInvariantValue(codegen, rewriter, val); 765 return val; 766 } 767 // Load during insertion. 768 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 769 if (t == codegen.sparseOut) 770 return genInsertionLoad(codegen, rewriter, op, t); 771 // Actual load. 772 SmallVector<Value, 4> args; 773 Value ptr = genSubscript(codegen, rewriter, op, t, args); 774 if (codegen.curVecLength > 1) 775 return genVectorLoad(codegen, rewriter, ptr, args); 776 return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args); 777 } 778 779 /// Generates a store on a dense or sparse tensor. 780 static void genTensorStore(Merger &merger, CodeGen &codegen, 781 PatternRewriter &rewriter, linalg::GenericOp op, 782 Value rhs) { 783 Location loc = op.getLoc(); 784 // Test if this is a scalarized reduction. 785 if (codegen.redVal) { 786 if (codegen.curVecLength > 1) 787 rhs = rewriter.create<arith::SelectOp>(loc, codegen.curVecMask, rhs, 788 codegen.redVal); 789 updateReduc(merger, codegen, rhs); 790 return; 791 } 792 // Store during insertion. 793 OpOperand *t = op.getOutputOperand(0); 794 if (t == codegen.sparseOut) { 795 genInsertionStore(codegen, rewriter, op, t, rhs); 796 return; 797 } 798 // Actual store. 799 SmallVector<Value, 4> args; 800 Value ptr = genSubscript(codegen, rewriter, op, t, args); 801 if (codegen.curVecLength > 1) 802 genVectorStore(codegen, rewriter, rhs, ptr, args); 803 else 804 rewriter.create<memref::StoreOp>(loc, rhs, ptr, args); 805 } 806 807 /// Generates a pointer/index load from the sparse storage scheme. 
/// Narrower data types need to be zero extended before casting the value
/// into the index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively defines an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vectorType(codegen, rewriter.getI32Type()), vload);
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vectorType(codegen, rewriter.getI64Type()), vload);
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), load);
    load =
        rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), load);
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv =
        rewriter.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul);
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<arith::AddIOp>(loc, mul, i);
}

/// Generates an index value.
static Value genIndexValue(Merger &merger, CodeGen &codegen, unsigned exp) {
  assert(codegen.curVecLength == 1); // TODO: implement vectorization!
  unsigned idx = merger.exp(exp).index;
  return codegen.loops[idx];
}

/// Recursively generates tensor expression.
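/// For example, for a(i) * b(i) + c(i) the merger holds a small expression
/// tree kAddF(kMulF(kTensor, kTensor), kTensor); the recursion below first
/// materializes the three tensor loads and then folds the results back
/// together through Merger::buildExp (an illustrative sketch of the
/// traversal).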
882 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 883 linalg::GenericOp op, unsigned exp) { 884 Location loc = op.getLoc(); 885 if (exp == -1u) 886 return Value(); 887 if (merger.exp(exp).kind == Kind::kTensor) 888 return genTensorLoad(merger, codegen, rewriter, op, exp); 889 if (merger.exp(exp).kind == Kind::kInvariant) 890 return genInvariantValue(merger, codegen, rewriter, exp); 891 if (merger.exp(exp).kind == Kind::kIndex) 892 return genIndexValue(merger, codegen, exp); 893 Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0); 894 Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1); 895 return merger.buildExp(rewriter, loc, exp, v0, v1); 896 } 897 898 /// Determines if affine expression is invariant. 899 static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a, 900 unsigned ldx, bool &atLevel) { 901 switch (a.getKind()) { 902 case AffineExprKind::DimId: { 903 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 904 if (idx == ldx) 905 atLevel = true; 906 return codegen.loops[idx] != nullptr; // no longer in play? 907 } 908 case AffineExprKind::Add: 909 case AffineExprKind::Mul: { 910 auto binOp = a.cast<AffineBinaryOpExpr>(); 911 return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) && 912 isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel); 913 } 914 default: 915 return true; 916 } 917 } 918 919 /// Hoists loop invariant tensor loads for which indices have been exhausted. 920 static void genInvariants(Merger &merger, CodeGen &codegen, 921 PatternRewriter &rewriter, linalg::GenericOp op, 922 unsigned exp, unsigned ldx, bool atStart, 923 Kind last = Kind::kTensor) { 924 if (exp == -1u) 925 return; 926 if (merger.exp(exp).kind == Kind::kTensor) { 927 // Inspect tensor indices. 928 bool atLevel = ldx == -1u; 929 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 930 auto map = op.getTiedIndexingMap(t); 931 auto enc = getSparseTensorEncoding(t->get().getType()); 932 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 933 AffineExpr a = map.getResult(perm(enc, d)); 934 if (!isInvariantAffine(codegen, a, ldx, atLevel)) 935 return; // still in play 936 } 937 // All exhausted at this level (atLevel denotes exactly at this level). 938 if (!atLevel) 939 return; 940 OpOperand *lhs = op.getOutputOperand(0); 941 if (lhs == t) { 942 // Start or end a scalarized reduction 943 if (atStart) { 944 Value load = genTensorLoad(merger, codegen, rewriter, op, exp); 945 codegen.redKind = getReduction(last); 946 codegen.redExp = exp; 947 updateReduc(merger, codegen, load); 948 } else { 949 Value redVal = codegen.redVal; 950 updateReduc(merger, codegen, Value()); 951 codegen.redExp = -1u; 952 codegen.redKind = kNoReduc; 953 genTensorStore(merger, codegen, rewriter, op, redVal); 954 } 955 } else { 956 // Start or end loop invariant hoisting of a tensor load. 957 merger.exp(exp).val = 958 atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); 959 } 960 } else if (merger.exp(exp).kind != Kind::kInvariant && 961 merger.exp(exp).kind != Kind::kIndex) { 962 // Traverse into the binary operations. Note that we only hoist 963 // tensor loads, since subsequent MLIR/LLVM passes know how to 964 // deal with all other kinds of derived loop invariants. 
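    // For example, in x(i,j) = a(i,j) * b(i), the load of b(i) has all of its
    // indices exhausted once the j-loop is entered, so the value is hoisted
    // into the i-loop and reused across all j iterations (an illustrative
    // sketch of the tensor-load hoisting performed here).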
965 Kind last = merger.exp(exp).kind; 966 unsigned e0 = merger.exp(exp).children.e0; 967 unsigned e1 = merger.exp(exp).children.e1; 968 genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last); 969 genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last); 970 } 971 } 972 973 /// Generates an expanded access pattern in innermost dimension. 974 static void genExpansion(Merger &merger, CodeGen &codegen, 975 PatternRewriter &rewriter, linalg::GenericOp op, 976 unsigned at, bool atStart) { 977 OpOperand *lhs = codegen.sparseOut; 978 if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 || 979 at != codegen.outerParNest) 980 return; // not needed at this level 981 // Generate start or end of an expanded access pattern. 982 Value tensor = lhs->get(); 983 Location loc = op.getLoc(); 984 if (atStart) { 985 auto dynShape = {ShapedType::kDynamicSize}; 986 Type etp = tensor.getType().cast<ShapedType>().getElementType(); 987 Type t1 = MemRefType::get(dynShape, etp); 988 Type t2 = MemRefType::get(dynShape, rewriter.getI1Type()); 989 Type t3 = MemRefType::get(dynShape, rewriter.getIndexType()); 990 Type t4 = rewriter.getIndexType(); 991 auto res = 992 rewriter.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor); 993 assert(res.getNumResults() == 4); 994 assert(!codegen.expValues); 995 codegen.expValues = res.getResult(0); 996 codegen.expFilled = res.getResult(1); 997 codegen.expAdded = res.getResult(2); 998 codegen.expCount = res.getResult(3); 999 } else { 1000 assert(codegen.expValues); 1001 rewriter.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues, 1002 codegen.expFilled, codegen.expAdded, 1003 codegen.expCount); 1004 codegen.expValues = codegen.expFilled = codegen.expAdded = 1005 codegen.expCount = Value(); 1006 } 1007 } 1008 1009 /// Generates initialization code for the subsequent loop sequence at 1010 /// current index level. Returns true if the loop sequence needs to 1011 /// maintain the universal index. 1012 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 1013 linalg::GenericOp op, std::vector<unsigned> &topSort, 1014 unsigned at, BitVector &inits) { 1015 bool needsUniv = false; 1016 Location loc = op.getLoc(); 1017 unsigned idx = topSort[at]; 1018 1019 // Initialize sparse positions. 1020 for (unsigned b = 0, be = inits.size(); b < be; b++) { 1021 if (inits[b]) { 1022 unsigned tensor = merger.tensor(b); 1023 assert(idx == merger.index(b)); 1024 if (merger.isDim(b, Dim::kSparse)) { 1025 // Initialize sparse index. 1026 unsigned pat = at; 1027 for (; pat != 0; pat--) { 1028 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1029 break; 1030 } 1031 Value ptr = codegen.pointers[tensor][idx]; 1032 Value one = constantIndex(rewriter, loc, 1); 1033 Value p0 = (pat == 0) ? constantIndex(rewriter, loc, 0) 1034 : codegen.pidxs[tensor][topSort[pat - 1]]; 1035 codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); 1036 Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one); 1037 codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1); 1038 } else { 1039 // Dense index still in play. 1040 needsUniv = true; 1041 } 1042 } 1043 } 1044 1045 // Initialize the universal dense index. 1046 codegen.loops[idx] = constantIndex(rewriter, loc, 0); 1047 return needsUniv; 1048 } 1049 1050 /// Returns vectorization strategy. Any implicit inner loop in the Linalg 1051 /// operation is a candidate. Whether it is actually converted to SIMD code 1052 /// depends on the requested strategy. 
1053 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isReduction, 1054 bool isSparse) { 1055 // Reject vectorization of sparse output, unless innermost is reduction. 1056 if (codegen.sparseOut && !isReduction) 1057 return false; 1058 // Inspect strategy. 1059 switch (codegen.options.vectorizationStrategy) { 1060 case SparseVectorizationStrategy::kNone: 1061 return false; 1062 case SparseVectorizationStrategy::kDenseInnerLoop: 1063 return isInner && !isSparse; 1064 case SparseVectorizationStrategy::kAnyStorageInnerLoop: 1065 return isInner; 1066 } 1067 llvm_unreachable("unexpected vectorization strategy"); 1068 } 1069 1070 /// Returns parallelization strategy. Any implicit loop in the Linalg operation 1071 /// that is marked "parallel" is a candidate. Whether it is actually converted 1072 /// to a parallel operation depends on the requested strategy. 1073 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, 1074 bool isSparse, bool isVector) { 1075 // Reject parallelization of sparse output. 1076 if (codegen.sparseOut) 1077 return false; 1078 // Inspect strategy. 1079 switch (codegen.options.parallelizationStrategy) { 1080 case SparseParallelizationStrategy::kNone: 1081 return false; 1082 case SparseParallelizationStrategy::kDenseOuterLoop: 1083 return isOuter && !isSparse && !isReduction && !isVector; 1084 case SparseParallelizationStrategy::kAnyStorageOuterLoop: 1085 return isOuter && !isReduction && !isVector; 1086 case SparseParallelizationStrategy::kDenseAnyLoop: 1087 return !isSparse && !isReduction && !isVector; 1088 case SparseParallelizationStrategy::kAnyStorageAnyLoop: 1089 return !isReduction && !isVector; 1090 } 1091 llvm_unreachable("unexpected parallelization strategy"); 1092 } 1093 1094 /// Checks unit stride for dense tensors. The iteration graph may have ignored 1095 /// dense access patterns in order to avoid cycles (sparse access patterns are 1096 /// always placed innermost), but that means dense access has become strided. 1097 /// This prevents effective vectorization. 1098 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 1099 unsigned idx) { 1100 for (OpOperand *t : op.getInputAndOutputOperands()) { 1101 if (!getSparseTensorEncoding(t->get().getType())) { 1102 auto map = op.getTiedIndexingMap(t); 1103 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1104 AffineExpr a = map.getResult(d); 1105 // Report non-unit stride if innermost index appears at an outer 1106 // dimension (true non-unit stride) or if the innermost index appears 1107 // in a compound subscript in the innermost dimension. Even if the 1108 // latter is unit stride, it does not play well with scatter/gather. 1109 // TODO: accept unit stride affine innermost like a[i,j+k+1]? 1110 if (a.isFunctionOfDim(idx) && 1111 ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) 1112 return false; 1113 } 1114 } 1115 } 1116 return true; 1117 } 1118 1119 /// Generates a for-loop on a single index. 
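/// For a compressed dimension, for example, this emits a loop over the
/// position range obtained from the pointers array, roughly
///   scf.for %p = %lo to %hi step %c1 { ... }
/// and, when vectorized, the step becomes the vector length while a pending
/// reduction is carried as an iter_args operand (an illustrative sketch).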
1120 static Operation *genFor(Merger &merger, CodeGen &codegen, 1121 PatternRewriter &rewriter, linalg::GenericOp op, 1122 bool isOuter, bool isInner, unsigned idx, 1123 BitVector &indices) { 1124 unsigned fb = indices.find_first(); 1125 unsigned tensor = merger.tensor(fb); 1126 assert(idx == merger.index(fb)); 1127 auto iteratorTypes = op.iterator_types().getValue(); 1128 bool isReduction = isReductionIterator(iteratorTypes[idx]); 1129 bool isSparse = merger.isDim(fb, Dim::kSparse); 1130 bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) && 1131 denseUnitStrides(merger, op, idx); 1132 bool isParallel = 1133 isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 1134 1135 // Prepare vector length. 1136 if (isVector) 1137 codegen.curVecLength = codegen.options.vectorLength; 1138 1139 // Loop bounds and increment. 1140 Location loc = op.getLoc(); 1141 Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 1142 Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 1143 Value step = constantIndex(rewriter, loc, codegen.curVecLength); 1144 1145 // Emit a parallel loop. 1146 if (isParallel) { 1147 assert(!isVector); 1148 scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step); 1149 if (isSparse) 1150 codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 1151 else 1152 codegen.loops[idx] = parOp.getInductionVars()[0]; 1153 rewriter.setInsertionPointToStart(parOp.getBody()); 1154 return parOp; 1155 } 1156 1157 // Emit a sequential or vector loop. 1158 SmallVector<Value, 4> operands; 1159 if (codegen.redVal) { 1160 // In a vector loop, bring reduction into SIMD form, if not already. 1161 if (isVector && !codegen.redVal.getType().isa<VectorType>()) { 1162 VectorType vtp = vectorType(codegen, codegen.redVal.getType()); 1163 Value vred = genVectorReducInit(codegen, rewriter, loc, vtp); 1164 updateReduc(merger, codegen, vred); 1165 } 1166 operands.push_back(codegen.redVal); 1167 } 1168 if (codegen.expValues) 1169 operands.push_back(codegen.expCount); 1170 scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands); 1171 if (codegen.redVal) 1172 updateReduc(merger, codegen, forOp.getRegionIterArgs().front()); 1173 if (codegen.expValues) 1174 codegen.expCount = forOp.getRegionIterArgs().back(); 1175 // Assign induction variable to sparse or dense index. 1176 Value iv = forOp.getInductionVar(); 1177 if (isSparse) 1178 codegen.pidxs[tensor][idx] = iv; 1179 else 1180 codegen.loops[idx] = iv; 1181 rewriter.setInsertionPointToStart(forOp.getBody()); 1182 // Share vector iteration mask between all subsequent loads/stores. 1183 if (isVector) 1184 codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step); 1185 return forOp; 1186 } 1187 1188 /// Emit a while-loop for co-iteration over multiple indices. 1189 static Operation *genWhile(Merger &merger, CodeGen &codegen, 1190 PatternRewriter &rewriter, linalg::GenericOp op, 1191 unsigned idx, bool needsUniv, 1192 BitVector &indices) { 1193 SmallVector<Type, 4> types; 1194 SmallVector<Value, 4> operands; 1195 // Construct the while-loop with a parameter for each index. 
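  // For example, co-iterating x(i) = a(i) + b(i) with both a and b sparse
  // carries one position per sparse tensor (plus the universal index when
  // still needed), and the loop runs while all of these positions remain
  // within their upper bounds (an illustrative sketch of the construction
  // below).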
1196 Type indexType = rewriter.getIndexType(); 1197 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1198 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1199 unsigned tensor = merger.tensor(b); 1200 assert(idx == merger.index(b)); 1201 types.push_back(indexType); 1202 operands.push_back(codegen.pidxs[tensor][idx]); 1203 } 1204 } 1205 if (codegen.redVal) { 1206 types.push_back(codegen.redVal.getType()); 1207 operands.push_back(codegen.redVal); 1208 } 1209 if (codegen.expValues) { 1210 types.push_back(indexType); 1211 operands.push_back(codegen.expCount); 1212 } 1213 if (needsUniv) { 1214 types.push_back(indexType); 1215 operands.push_back(codegen.loops[idx]); 1216 } 1217 assert(types.size() == operands.size()); 1218 Location loc = op.getLoc(); 1219 scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); 1220 1221 SmallVector<Location> locs(types.size(), loc); 1222 Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, locs); 1223 Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, types, locs); 1224 1225 // Build the "before" region, which effectively consists 1226 // of a conjunction of "i < upper" tests on all induction. 1227 rewriter.setInsertionPointToStart(&whileOp.getBefore().front()); 1228 Value cond; 1229 unsigned o = 0; 1230 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1231 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1232 unsigned tensor = merger.tensor(b); 1233 assert(idx == merger.index(b)); 1234 Value op1 = before->getArgument(o); 1235 Value op2 = codegen.highs[tensor][idx]; 1236 Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 1237 op1, op2); 1238 cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc; 1239 codegen.pidxs[tensor][idx] = after->getArgument(o++); 1240 } 1241 } 1242 if (codegen.redVal) 1243 updateReduc(merger, codegen, after->getArgument(o++)); 1244 if (codegen.expValues) 1245 codegen.expCount = after->getArgument(o++); 1246 if (needsUniv) 1247 codegen.loops[idx] = after->getArgument(o++); 1248 assert(o == operands.size()); 1249 rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); 1250 rewriter.setInsertionPointToStart(&whileOp.getAfter().front()); 1251 return whileOp; 1252 } 1253 1254 /// Generates a for-loop or a while-loop, depending on whether it implements 1255 /// singleton iteration or co-iteration over the given conjunction. 1256 static Operation *genLoop(Merger &merger, CodeGen &codegen, 1257 PatternRewriter &rewriter, linalg::GenericOp op, 1258 std::vector<unsigned> &topSort, unsigned at, 1259 bool needsUniv, BitVector &indices) { 1260 unsigned idx = topSort[at]; 1261 if (indices.count() == 1) { 1262 bool isOuter = at == 0; 1263 bool isInner = at == topSort.size() - 1; 1264 return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, 1265 indices); 1266 } 1267 return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); 1268 } 1269 1270 /// Generates the local variables for this loop, consisting of the sparse 1271 /// indices, restored universal dense index, and dense positions. 1272 static void genLocals(Merger &merger, CodeGen &codegen, 1273 PatternRewriter &rewriter, linalg::GenericOp op, 1274 std::vector<unsigned> &topSort, unsigned at, 1275 bool needsUniv, BitVector &locals) { 1276 Location loc = op.getLoc(); 1277 unsigned idx = topSort[at]; 1278 1279 // Initialize sparse indices. 
1280 Value min; 1281 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1282 if (locals[b] && merger.isDim(b, Dim::kSparse)) { 1283 unsigned tensor = merger.tensor(b); 1284 assert(idx == merger.index(b)); 1285 Value ptr = codegen.indices[tensor][idx]; 1286 Value s = codegen.pidxs[tensor][idx]; 1287 Value load = genLoad(codegen, rewriter, loc, ptr, s); 1288 codegen.idxs[tensor][idx] = load; 1289 if (!needsUniv) { 1290 if (min) { 1291 Value cmp = rewriter.create<arith::CmpIOp>( 1292 loc, arith::CmpIPredicate::ult, load, min); 1293 min = rewriter.create<arith::SelectOp>(loc, cmp, load, min); 1294 } else { 1295 min = load; 1296 } 1297 } 1298 } 1299 } 1300 1301 // Merge dense universal index over minimum. 1302 if (min) { 1303 assert(!needsUniv); 1304 codegen.loops[idx] = min; 1305 } 1306 1307 // Initialize dense positions. Note that we generate dense indices of the 1308 // output tensor unconditionally, since they may not appear in the lattice, 1309 // but may be needed for linearized codegen. 1310 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1311 if ((locals[b] || merger.isOutTensor(b, idx)) && 1312 merger.isDim(b, Dim::kDense)) { 1313 unsigned tensor = merger.tensor(b); 1314 assert(idx == merger.index(b)); 1315 unsigned pat = at; 1316 for (; pat != 0; pat--) 1317 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1318 break; 1319 Value p = (pat == 0) ? constantIndex(rewriter, loc, 0) 1320 : codegen.pidxs[tensor][topSort[pat - 1]]; 1321 codegen.pidxs[tensor][idx] = genAddress( 1322 codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); 1323 } 1324 } 1325 1326 // Move the insertion indices in lexicographic index order. During access 1327 // pattern expansion, we can skip setting the innermost dimension. 1328 if (codegen.sparseOut && !codegen.expValues) { 1329 Value pos = constantIndex(rewriter, loc, at); 1330 rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx, 1331 pos); 1332 } 1333 } 1334 1335 /// Generates the induction structure for a while-loop. 1336 static void genWhileInduction(Merger &merger, CodeGen &codegen, 1337 PatternRewriter &rewriter, linalg::GenericOp op, 1338 unsigned idx, bool needsUniv, 1339 BitVector &induction, 1340 scf::WhileOp whileOp) { 1341 Location loc = op.getLoc(); 1342 // Finalize each else branch of all if statements. 1343 if (codegen.redVal || codegen.expValues) { 1344 while (auto ifOp = dyn_cast_or_null<scf::IfOp>( 1345 rewriter.getInsertionBlock()->getParentOp())) { 1346 unsigned y = 0; 1347 SmallVector<Value, 4> yields; 1348 if (codegen.redVal) { 1349 yields.push_back(codegen.redVal); 1350 updateReduc(merger, codegen, ifOp.getResult(y++)); 1351 } 1352 if (codegen.expValues) { 1353 yields.push_back(codegen.expCount); 1354 codegen.expCount = ifOp->getResult(y++); 1355 } 1356 assert(y == yields.size()); 1357 rewriter.create<scf::YieldOp>(loc, yields); 1358 rewriter.setInsertionPointAfter(ifOp); 1359 } 1360 } 1361 rewriter.setInsertionPointToEnd(&whileOp.getAfter().front()); 1362 // Finalize the induction. Note that the induction could be performed 1363 // in the individual if-branches to avoid re-evaluating the conditions. 1364 // However, that would result in a rather elaborate forest of yield 1365 // instructions during code generation. Moreover, performing the induction 1366 // after the if-statements more closely resembles code generated by TACO. 
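  // In effect, for co-iteration over sparse a and b this yields, per
  // iteration,
  //   pa = (idx_a == i) ? pa + 1 : pa
  //   pb = (idx_b == i) ? pb + 1 : pb
  //   i  = i + 1   (only when the universal index is maintained)
  // as the new while-loop operands (an illustrative sketch).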
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = constantIndex(rewriter, loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 op1, op2);
      Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<arith::SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
    }
  }
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, whileOp->getResult(o++));
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = whileOp->getResult(o++);
  }
  if (needsUniv) {
    operands.push_back(
        rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = whileOp->getResult(o++);
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(whileOp);
}

/// Generates the induction structure for a for-loop.
static void genForInduction(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            Operation *loop) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, loop->getResult(o++));
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = loop->getResult(o++);
  }
  assert(o == operands.size());
  if (o > 0)
    rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(loop);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, BitVector &conditions) {
  Location loc = op.getLoc();
  SmallVector<Type, 4> types;
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                op1, op2);
      } else {
        clause = constantI1(rewriter, loc, true);
      }
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
    }
  }
  if (codegen.redVal)
    types.push_back(codegen.redVal.getType());
  if (codegen.expValues)
    types.push_back(rewriter.getIndexType());
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}

/// Generates end of true branch of if-statement within a while-loop.
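/// Yields the current reduction value and expansion count (when present),
/// restores the values that were live on entry for use in the else branch,
/// and moves the insertion point to the start of the else branch; the else
/// branches themselves are finalized later in genWhileInduction.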
static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                  linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
                  Value redInput, Value cntInput) {
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, redInput);
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = cntInput;
  }
  if (!operands.empty())
    rewriter.create<scf::YieldOp>(op.getLoc(), operands);
  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         std::vector<unsigned> &topSort, unsigned exp,
                         unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  assert(!codegen.loops[idx]);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
  // Emit access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  bool needsUniv =
      genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  if (needsUniv) {
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}

/// Starts a single loop in current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in current sequence. Returns the new value for
/// needsUniv.
static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp);
    return needsUniv;
  }
  // End a for-loop.
  genForInduction(merger, codegen, rewriter, op, loop);
  return false;
}

/// Ends a loop sequence at given level.
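/// Clears the loop index, folds a vectorized reduction back into scalar form
/// if needed, and unwinds the invariant and access pattern expansion
/// bookkeeping that was set up by startLoopSeq.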
static void endLoopSeq(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned exp, unsigned at, unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  codegen.loops[idx] = Value();
  // Bring a pending reduction back from SIMD form when sequence ends.
  if (codegen.redVal)
    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
      updateReduc(merger, codegen,
                  genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
  // Finalize access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/false);
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for coiteration.
    Value redInput = codegen.redVal;
    Value cntInput = codegen.expCount;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
}

/// Converts the result computed by the sparse kernel into the required form.
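/// A sparse output is rematerialized with sparse_tensor.load (marked to
/// finalize pending insertions when the kernel inserted directly into the
/// sparse output), whereas a dense output is rematerialized from its
/// bufferized values array with bufferization.to_tensor.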
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format.
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
                                        codegen.sparseOut == lhs);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = codegen.buffers.back(); // value array
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
  }
}

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {

    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
    if (!optExp.hasValue())
      return failure();
    unsigned exp = optExp.getValue();

    // Rejects an inadmissable tensor expression.
    OpOperand *sparseOut = nullptr;
    unsigned outerParNest = 0;
    if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
                               outerParNest))
      return failure();

    // Recursively generates code.
    merger.setHasSparseOut(sparseOut != nullptr);
    CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
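
// A typical driver for these patterns might look as follows (illustrative
// sketch only; the actual sparsification pass lives elsewhere in the
// SparseTensor transforms):
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));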