1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements converting sparse tensor types to actual sparse code. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CodegenUtils.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 17 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21 #include "mlir/Dialect/SCF/SCF.h" 22 #include "mlir/Dialect/SCF/Transforms.h" 23 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 24 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 25 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 26 #include "mlir/Dialect/Vector/IR/VectorOps.h" 27 #include "mlir/IR/Matchers.h" 28 #include "mlir/IR/TensorEncoding.h" 29 #include "llvm/ADT/SmallBitVector.h" 30 31 using namespace mlir; 32 using namespace mlir::sparse_tensor; 33 34 //===----------------------------------------------------------------------===// 35 // Declarations of data structures. 36 //===----------------------------------------------------------------------===// 37 38 namespace { 39 40 // Iteration graph sorting. 41 enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 }; 42 43 // Reduction kinds. 44 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor }; 45 46 // Code generation. 47 struct CodeGen { 48 CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops, 49 OpOperand *op, unsigned nest) 50 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 51 pointers(numTensors, std::vector<Value>(numLoops)), 52 indices(numTensors, std::vector<Value>(numLoops)), 53 highs(numTensors, std::vector<Value>(numLoops)), 54 pidxs(numTensors, std::vector<Value>(numLoops)), 55 idxs(numTensors, std::vector<Value>(numLoops)), redVal(), sparseOut(op), 56 outerParNest(nest), lexIdx(), expValues(), expFilled(), expAdded(), 57 expCount(), curVecMask() {} 58 /// Sparsification options. 59 SparsificationOptions options; 60 /// Universal dense indices and upper bounds (by index). The loops array 61 /// is updated with the value of the universal dense index in the current 62 /// loop. The sizes array is set once with the inferred dimension sizes. 63 std::vector<Value> loops; 64 std::vector<Value> sizes; 65 /// Buffers for storing dense and sparse numerical values (by tensor). 66 /// This array is set once during bufferization of all tensors. 67 std::vector<Value> buffers; 68 /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 69 /// This array is set once during bufferization of all sparse tensors. 70 std::vector<std::vector<Value>> pointers; 71 std::vector<std::vector<Value>> indices; 72 /// Sparse iteration information (by tensor and index). These arrays 73 /// are updated to remain current within the current loop. 
74 std::vector<std::vector<Value>> highs; 75 std::vector<std::vector<Value>> pidxs; 76 std::vector<std::vector<Value>> idxs; 77 /// Current reduction, updated during code generation. When indices of a 78 /// reduction are exhausted, all inner loops can use a scalarized reduction. 79 unsigned redExp = -1u; 80 Value redVal; 81 Reduction redKind = kNoReduc; 82 // Sparse tensor as output. Implemented either through direct injective 83 // insertion in lexicographic index order (where indices are updated 84 // in the temporary array `lexIdx`) or through access pattern expansion 85 // in the innermost loop nest (`expValues` through `expCount`). 86 OpOperand *sparseOut; 87 unsigned outerParNest; 88 Value lexIdx; 89 Value expValues; 90 Value expFilled; 91 Value expAdded; 92 Value expCount; 93 // Current vector length and mask. 94 unsigned curVecLength = 1; 95 Value curVecMask; 96 }; 97 98 } // namespace 99 100 //===----------------------------------------------------------------------===// 101 // Sparse compiler analysis methods. 102 //===----------------------------------------------------------------------===// 103 104 /// Helper method to apply dimension ordering permutation. 105 static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) { 106 if (enc) { 107 auto order = enc.getDimOrdering(); 108 if (order) { 109 assert(order.isPermutation()); 110 return order.getDimPosition(d); 111 } 112 } 113 return d; 114 } 115 116 /// Helper method to translate dim level type to internal representation. 117 static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) { 118 if (enc) { 119 SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 120 if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 121 return Dim::kSparse; 122 if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 123 return Dim::kSingle; 124 } 125 return Dim::kDense; 126 } 127 128 /// Helper method to inspect affine expressions. Rejects cases where the 129 /// same index is used more than once. Also rejects affine expressions 130 /// that are not a direct index for annotated tensors. 131 // TODO: accept more affine cases for sparse tensors 132 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim, 133 bool isDense) { 134 switch (a.getKind()) { 135 case AffineExprKind::DimId: { 136 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 137 if (!merger.isDim(tensor, idx, Dim::kUndef)) 138 return false; // used more than once 139 merger.setDim(tensor, idx, dim); 140 return true; 141 } 142 case AffineExprKind::Add: 143 case AffineExprKind::Mul: { 144 if (!isDense) 145 return false; 146 auto binOp = a.cast<AffineBinaryOpExpr>(); 147 return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) && 148 findAffine(merger, tensor, binOp.getRHS(), dim, isDense); 149 } 150 case AffineExprKind::Constant: 151 return isDense; 152 default: 153 return false; 154 } 155 } 156 157 /// Helper method to inspect sparse encodings in the tensor types. 158 /// Fills the per-dimension sparsity information for all tensors. 159 /// Returns true if the sparse annotations and affine subscript 160 /// expressions of all tensors are admissable. Returns false if 161 /// no annotations are found or inadmissable constructs occur. 
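///
/// As a rough illustrative sketch (not an exhaustive list), for a tensor A
/// carrying a sparse encoding:
///   A(i, j)      is accepted: every dimension is addressed by a plain index
///   A(i, j + k)  is rejected: compound affine subscript on an annotated tensor
///   A(i, i)      is rejected: the same index is used more than once
/// Dense (non-annotated) tensors may still use compound or constant subscripts.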
162 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 163 bool annotated = false; 164 for (OpOperand *t : op.getInputAndOutputOperands()) { 165 auto map = op.getTiedIndexingMap(t); 166 auto enc = getSparseTensorEncoding(t->get().getType()); 167 if (enc) 168 annotated = true; 169 assert(map.getNumResults() == op.getRank(t)); 170 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 171 unsigned tensor = t->getOperandNumber(); 172 AffineExpr a = map.getResult(perm(enc, d)); 173 if (!findAffine(merger, tensor, a, toDim(enc, d), !enc)) 174 return false; // inadmissable affine expression 175 } 176 } 177 return annotated; 178 } 179 180 /// A DFS helper to compute a topological sort. Note that recursion is 181 /// bounded by the number of implicit loops, which is always small. 182 /// Returns false when a cycle is detected. 183 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 184 std::vector<unsigned> &topSort, 185 std::vector<std::vector<bool>> &adjM) { 186 if (visit[i] != 0) 187 return visit[i] != 1; // 1 denotes cycle! 188 visit[i] = 1; 189 for (unsigned j = 0, e = visit.size(); j < e; j++) 190 if (adjM[i][j]) 191 if (!topSortDFS(j, visit, topSort, adjM)) 192 return false; 193 visit[i] = 2; 194 topSort.push_back(i); 195 return true; 196 } 197 198 /// Helper method to add all constraints from the indices in one affine 199 /// expression before all indices in the other affine expression. For 200 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3. 201 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM, 202 AffineExpr a, AffineExpr b, unsigned fidx) { 203 switch (a.getKind()) { 204 case AffineExprKind::DimId: { 205 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 206 if (b) 207 addAffineOrderings(adjM, b, AffineExpr(), idx); 208 else 209 adjM[fidx][idx] = true; 210 break; 211 } 212 case AffineExprKind::Add: 213 case AffineExprKind::Mul: { 214 auto binOp = a.cast<AffineBinaryOpExpr>(); 215 addAffineOrderings(adjM, binOp.getLHS(), b, fidx); 216 addAffineOrderings(adjM, binOp.getRHS(), b, fidx); 217 break; 218 } 219 default: 220 break; 221 } 222 } 223 224 /// Computes a topologically sorted iteration graph for the linalg operation. 225 /// Ensures all tensors are visited in natural index order. This is essential 226 /// for sparse storage formats since these only support access along fixed 227 /// dimensions. Even for dense storage formats, however, the natural index 228 /// order yields innermost unit-stride access with better spatial locality. 229 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, 230 std::vector<unsigned> &topSort, 231 unsigned mask) { 232 // Set up an n x n from/to adjacency matrix of the iteration graph 233 // for the implicit loop indices i_0 .. i_n-1. 234 unsigned n = op.getNumLoops(); 235 std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false)); 236 237 // Iterate over the indexing maps of every tensor in the tensor expression. 238 for (OpOperand *t : op.getInputAndOutputOperands()) { 239 auto map = op.getTiedIndexingMap(t); 240 auto enc = getSparseTensorEncoding(t->get().getType()); 241 assert(map.getNumDims() == n); 242 // Skip dense tensor constraints when not requested. 243 if (!(mask & SortMask::kIncludeDense) && !enc) 244 continue; 245 // Each tensor expression and optional dimension ordering (row-major 246 // by default) puts an ordering constraint on the loop indices. 
    // For example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if tensor has an in-place annotation.
static bool isInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(),
              bufferization::BufferizableOpInterface::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Returns true if tensor materializes uninitialized into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<InitOp>();
}

/// Returns true when the tensor expression is admissable for codegen.
/// Since all sparse input tensors are admissable, we just need to check
/// whether the out tensor in the tensor expression codegen is admissable.
/// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
/// nesting depth when a "truly dynamic" sparse tensor output occurs.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort, unsigned exp,
                                  OpOperand **sparseOut,
                                  unsigned &outerParNest) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissable since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissable since insertions cannot occur.
  bool allDense = true;
  auto iteratorTypes = op.iterator_types().getValue();
  unsigned numLoops = iteratorTypes.size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissable without special codegen, provided
  // the tensor's underlying sparse storage scheme can be modified in place.
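  // For example (sketch), x(i) = x(i) * 2.0 with a sparse, in-place x only
  // rewrites the stored values array; the pointers/indices arrays, and thus
  // the nonzero structure, remain untouched.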
328 if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get())) 329 return true; 330 // Accept "truly dynamic" if the output tensor materializes uninitialized 331 // into the computation and insertions occur in lexicographic index order. 332 if (isMaterializing(lhs->get())) { 333 unsigned nest = 0; 334 for (unsigned i = 0; i < numLoops; i++) { 335 if (isReductionIterator(iteratorTypes[topSort[i]])) 336 break; // terminate at first reduction 337 nest++; 338 } 339 // Determine admissable dynamic insertion situations: 340 // (1) fully injective, since there are no reductions, 341 // (2) admissable 1-d expansion in innermost dimension. 342 if (nest >= op.getRank(lhs) - 1) { 343 *sparseOut = lhs; 344 outerParNest = nest; 345 return true; 346 } 347 } 348 return false; 349 } 350 351 //===----------------------------------------------------------------------===// 352 // Sparse compiler synthesis methods (reductions). 353 //===----------------------------------------------------------------------===// 354 355 /// Maps reduction kind to vector::CombiningKind. 356 static vector::CombiningKind getCombiningKind(Reduction kind) { 357 switch (kind) { 358 case kNoReduc: 359 break; 360 case kSum: 361 return vector::CombiningKind::ADD; 362 case kProduct: 363 return vector::CombiningKind::MUL; 364 case kAnd: 365 return vector::CombiningKind::AND; 366 case kOr: 367 return vector::CombiningKind::OR; 368 case kXor: 369 return vector::CombiningKind::XOR; 370 } 371 llvm_unreachable("unknown reduction kind"); 372 } 373 374 /// Maps operation to reduction. 375 static Reduction getReduction(Kind kind) { 376 switch (kind) { 377 case Kind::kAddF: 378 case Kind::kAddI: 379 case Kind::kSubF: 380 case Kind::kSubI: 381 return kSum; 382 case Kind::kMulF: 383 case Kind::kMulI: 384 return kProduct; 385 case Kind::kAndI: 386 return kAnd; 387 case Kind::kOrI: 388 return kOr; 389 case Kind::kXorI: 390 return kXor; 391 default: 392 llvm_unreachable("unexpected reduction operator"); 393 } 394 } 395 396 /// Generates an initial value for a vector reduction, following the scheme 397 /// given in Chapter 5 of "The Software Vectorization Handbook", where the 398 /// initial scalar value is correctly embedded in the vector reduction value, 399 /// and a straightforward horizontal reduction will complete the operation. 400 static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter, 401 Location loc, VectorType vtp) { 402 Value r = codegen.redVal; 403 switch (codegen.redKind) { 404 case kNoReduc: 405 break; 406 case kSum: 407 case kXor: 408 // Initialize reduction vector to: | 0 | .. | 0 | r | 409 return rewriter.create<vector::InsertElementOp>( 410 loc, r, constantZero(rewriter, loc, vtp), 411 constantIndex(rewriter, loc, 0)); 412 case kProduct: 413 // Initialize reduction vector to: | 1 | .. | 1 | r | 414 return rewriter.create<vector::InsertElementOp>( 415 loc, r, constantOne(rewriter, loc, vtp), 416 constantIndex(rewriter, loc, 0)); 417 case kAnd: 418 case kOr: 419 // Initialize reduction vector to: | r | .. | r | r | 420 return rewriter.create<vector::BroadcastOp>(loc, vtp, r); 421 } 422 llvm_unreachable("unknown reduction kind"); 423 } 424 425 /// Generates final value for a vector reduction. 426 static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter, 427 Location loc, VectorType vtp) { 428 vector::CombiningKind kind = getCombiningKind(codegen.redKind); 429 return rewriter.create<vector::ReductionOp>(loc, kind, codegen.redVal); 430 } 431 432 /// Updates scalarized reduction value. 
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (isInPlace(tensor))
    return rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel. If the tensor
  // materializes into the computation, we need to preserve the zero
  // initialization assumption of all sparse output buffers.
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  if (isMaterializing(tensor)) {
    Value zero = constantZero(rewriter, loc, denseTp.getElementType());
    rewriter.create<linalg::FillOp>(loc, zero, alloc);
  } else {
    Value init =
        rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
    rewriter.create<memref::CopyOp>(loc, init, alloc);
  }
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
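      // For a compressed dimension, this roughly materializes the following
      // primitives (sketch; the exact memref types depend on the encoding's
      // pointer/index overhead bit widths):
      //   %ptrs = sparse_tensor.pointers %t, %d : ... to memref<?xindex>
      //   %inds = sparse_tensor.indices %t, %d : ... to memref<?xindex>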
500 if (merger.isDim(tensor, idx, Dim::kSparse)) { 501 auto dynShape = {ShapedType::kDynamicSize}; 502 auto ptrTp = 503 MemRefType::get(dynShape, getPointerOverheadType(rewriter, enc)); 504 auto indTp = 505 MemRefType::get(dynShape, getIndexOverheadType(rewriter, enc)); 506 Value dim = constantIndex(rewriter, loc, d); 507 // Generate sparse primitives to obtains pointer and indices. 508 codegen.pointers[tensor][idx] = 509 rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim); 510 codegen.indices[tensor][idx] = 511 rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim); 512 } 513 // Find upper bound in current dimension. 514 unsigned p = perm(enc, d); 515 Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p); 516 if (ShapedType::isDynamic(shape[p])) 517 args.push_back(up); 518 assert(codegen.highs[tensor][idx] == nullptr); 519 codegen.sizes[idx] = codegen.highs[tensor][idx] = up; 520 } 521 // Perform the required bufferization. Dense inputs materialize 522 // from the input tensors. Dense outputs need special handling. 523 // Sparse inputs use sparse primitives to obtain the values. 524 // We also accept in-place all-dense annotated "sparse" outputs. 525 Type elementType = getElementTypeOrSelf(t->get().getType()); 526 if (!enc) { 527 // Non-annotated dense tensors. 528 auto denseTp = MemRefType::get(shape, elementType); 529 if (tensor < op.getNumInputs()) 530 codegen.buffers[tensor] = 531 rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, t->get()); 532 else 533 codegen.buffers[tensor] = 534 genOutputBuffer(codegen, rewriter, op, denseTp, args); 535 } else if (t == codegen.sparseOut) { 536 // True sparse output needs a lexIdx array. 537 Value rank = constantIndex(rewriter, loc, op.getRank(t)); 538 auto dynShape = {ShapedType::kDynamicSize}; 539 auto memTp = MemRefType::get(dynShape, rewriter.getIndexType()); 540 codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank); 541 } else { 542 // Annotated sparse tensors. 543 auto dynShape = {ShapedType::kDynamicSize}; 544 auto sparseTp = MemRefType::get(dynShape, elementType); 545 codegen.buffers[tensor] = 546 rewriter.create<ToValuesOp>(loc, sparseTp, t->get()); 547 } 548 } 549 } 550 551 /// Constructs vector type. 552 static VectorType vectorType(CodeGen &codegen, Type etp) { 553 return VectorType::get(codegen.curVecLength, etp); 554 } 555 556 /// Constructs vector type from pointer. 557 static VectorType vectorType(CodeGen &codegen, Value ptr) { 558 return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType()); 559 } 560 561 /// Constructs vector iteration mask. 562 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter, 563 Value iv, Value lo, Value hi, Value step) { 564 Location loc = iv.getLoc(); 565 VectorType mtp = vectorType(codegen, rewriter.getI1Type()); 566 // Special case if the vector length evenly divides the trip count (for 567 // example, "for i = 0, 128, 16"). A constant all-true mask is generated 568 // so that all subsequent masked memory operations are immediately folded 569 // into unconditional memory operations. 
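  // E.g. (sketch) with lo = 0, hi = 128, step = 16 the mask is simply:
  //   %mask = vector.broadcast %true : i1 to vector<16xi1>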
570 IntegerAttr loInt, hiInt, stepInt; 571 if (matchPattern(lo, m_Constant(&loInt)) && 572 matchPattern(hi, m_Constant(&hiInt)) && 573 matchPattern(step, m_Constant(&stepInt))) { 574 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) 575 return rewriter.create<vector::BroadcastOp>( 576 loc, mtp, constantI1(rewriter, loc, true)); 577 } 578 // Otherwise, generate a vector mask that avoids overrunning the upperbound 579 // during vector execution. Here we rely on subsequent loop optimizations to 580 // avoid executing the mask in all iterations, for example, by splitting the 581 // loop into an unconditional vector loop and a scalar cleanup loop. 582 auto minMap = AffineMap::get( 583 /*dimCount=*/2, /*symbolCount=*/1, 584 {rewriter.getAffineSymbolExpr(0), 585 rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)}, 586 rewriter.getContext()); 587 Value end = 588 rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step}); 589 return rewriter.create<vector::CreateMaskOp>(loc, mtp, end); 590 } 591 592 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. 593 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter, 594 Value ptr, ArrayRef<Value> args) { 595 Location loc = ptr.getLoc(); 596 VectorType vtp = vectorType(codegen, ptr); 597 Value pass = constantZero(rewriter, loc, vtp); 598 if (args.back().getType().isa<VectorType>()) { 599 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 600 Value indexVec = args.back(); 601 scalarArgs.back() = constantIndex(rewriter, loc, 0); 602 return rewriter.create<vector::GatherOp>( 603 loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); 604 } 605 return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args, 606 codegen.curVecMask, pass); 607 } 608 609 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. 610 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter, 611 Value rhs, Value ptr, ArrayRef<Value> args) { 612 Location loc = ptr.getLoc(); 613 if (args.back().getType().isa<VectorType>()) { 614 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 615 Value indexVec = args.back(); 616 scalarArgs.back() = constantIndex(rewriter, loc, 0); 617 rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, 618 codegen.curVecMask, rhs); 619 return; 620 } 621 rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask, 622 rhs); 623 } 624 625 /// Generates a vectorized invariant. Here we rely on subsequent loop 626 /// optimizations to hoist the invariant broadcast out of the vector loop. 627 static Value genVectorInvariantValue(CodeGen &codegen, 628 PatternRewriter &rewriter, Value val) { 629 VectorType vtp = vectorType(codegen, val.getType()); 630 return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 631 } 632 633 /// Generates an affine expression. 
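/// For example (sketch), the dense subscript expression i + 2 * j is emitted
/// as an arith.muli of the constant 2 with the loop index for j, followed by
/// an arith.addi with the loop index for i, all on the universal indices kept
/// in codegen.loops.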
634 // 635 // TODO: generalize for sparse tensor subscripts 636 // 637 static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter, 638 AffineExpr a, Location loc) { 639 switch (a.getKind()) { 640 case AffineExprKind::DimId: { 641 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 642 return codegen.loops[idx]; // universal dense index 643 } 644 case AffineExprKind::Add: { 645 auto binOp = a.cast<AffineBinaryOpExpr>(); 646 return rewriter.create<arith::AddIOp>( 647 loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 648 genAffine(codegen, rewriter, binOp.getRHS(), loc)); 649 } 650 case AffineExprKind::Mul: { 651 auto binOp = a.cast<AffineBinaryOpExpr>(); 652 return rewriter.create<arith::MulIOp>( 653 loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 654 genAffine(codegen, rewriter, binOp.getRHS(), loc)); 655 } 656 case AffineExprKind::Constant: { 657 int64_t c = a.cast<AffineConstantExpr>().getValue(); 658 return constantIndex(rewriter, loc, c); 659 } 660 default: 661 llvm_unreachable("unexpected affine subscript"); 662 } 663 } 664 665 /// Generates index for load/store on sparse tensor. 666 static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) { 667 auto map = op.getTiedIndexingMap(t); 668 auto enc = getSparseTensorEncoding(t->get().getType()); 669 AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1)); 670 assert(a.getKind() == AffineExprKind::DimId); 671 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 672 return codegen.loops[idx]; 673 } 674 675 /// Generates subscript for load/store on a dense or sparse tensor. 676 static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter, 677 linalg::GenericOp op, OpOperand *t, 678 SmallVector<Value, 4> &args) { 679 unsigned tensor = t->getOperandNumber(); 680 auto map = op.getTiedIndexingMap(t); 681 auto enc = getSparseTensorEncoding(t->get().getType()); 682 unsigned rank = map.getNumResults(); 683 if (enc) { 684 // Note that currently, all sparse subscripts are simple. 685 // TODO: accept affine too? 686 AffineExpr a = map.getResult(perm(enc, rank - 1)); 687 assert(a.getKind() == AffineExprKind::DimId); 688 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 689 assert(codegen.pidxs[tensor][idx] != nullptr); 690 args.push_back(codegen.pidxs[tensor][idx]); // position index 691 } else { 692 for (unsigned d = 0; d < rank; d++) { 693 AffineExpr a = map.getResult(perm(enc, d)); 694 args.push_back(genAffine(codegen, rewriter, a, op.getLoc())); 695 } 696 } 697 return codegen.buffers[tensor]; 698 } 699 700 /// Generates insertion code to implement dynamic tensor load. 701 static Value genInsertionLoad(CodeGen &codegen, PatternRewriter &rewriter, 702 linalg::GenericOp op, OpOperand *t) { 703 Location loc = op.getLoc(); 704 // Direct lexicographic index order, tensor loads as zero. 705 if (!codegen.expValues) { 706 Type tp = getElementTypeOrSelf(t->get().getType()); 707 return constantZero(rewriter, loc, tp); 708 } 709 // Load from expanded access pattern. 710 Value index = genIndex(codegen, op, t); 711 return rewriter.create<memref::LoadOp>(loc, codegen.expValues, index); 712 } 713 714 /// Generates insertion code to implement dynamic tensor store. 715 static void genInsertionStore(CodeGen &codegen, PatternRewriter &rewriter, 716 linalg::GenericOp op, OpOperand *t, Value rhs) { 717 Location loc = op.getLoc(); 718 // Direct insertion in lexicographic index order. 
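  // (Sketch) this emits a LexInsertOp, roughly:
  //   sparse_tensor.lex_insert %tensor, %lexIdx, %rhs
  // which inserts the value at the coordinates held in the lexIdx buffer.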
719 if (!codegen.expValues) { 720 rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs); 721 return; 722 } 723 // Generates insertion code along expanded access pattern. 724 // if (!expFilled[i]) then 725 // expFilled[i] = true 726 // expAdded[inserts++] = i 727 // endif 728 // values[i] = rhs 729 Value index = genIndex(codegen, op, t); 730 Value fval = constantI1(rewriter, loc, false); 731 Value tval = constantI1(rewriter, loc, true); 732 // If statement. 733 Value filled = rewriter.create<memref::LoadOp>(loc, codegen.expFilled, index); 734 Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 735 filled, fval); 736 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, rewriter.getIndexType(), 737 cond, /*else=*/true); 738 // True branch. 739 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); 740 rewriter.create<memref::StoreOp>(loc, tval, codegen.expFilled, index); 741 rewriter.create<memref::StoreOp>(loc, index, codegen.expAdded, 742 codegen.expCount); 743 Value one = constantIndex(rewriter, loc, 1); 744 Value add = rewriter.create<arith::AddIOp>(loc, codegen.expCount, one); 745 rewriter.create<scf::YieldOp>(loc, add); 746 // False branch. 747 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); 748 rewriter.create<scf::YieldOp>(loc, codegen.expCount); 749 rewriter.setInsertionPointAfter(ifOp); 750 // Value assignment. 751 codegen.expCount = ifOp.getResult(0); 752 rewriter.create<memref::StoreOp>(loc, rhs, codegen.expValues, index); 753 } 754 755 /// Generates a load on a dense or sparse tensor. 756 static Value genTensorLoad(Merger &merger, CodeGen &codegen, 757 PatternRewriter &rewriter, linalg::GenericOp op, 758 unsigned exp) { 759 // Test if the load was hoisted to a higher loop nest. 760 Value val = merger.exp(exp).val; 761 if (val) { 762 if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>()) 763 return genVectorInvariantValue(codegen, rewriter, val); 764 return val; 765 } 766 // Load during insertion. 767 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 768 if (t == codegen.sparseOut) 769 return genInsertionLoad(codegen, rewriter, op, t); 770 // Actual load. 771 SmallVector<Value, 4> args; 772 Value ptr = genSubscript(codegen, rewriter, op, t, args); 773 if (codegen.curVecLength > 1) 774 return genVectorLoad(codegen, rewriter, ptr, args); 775 return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args); 776 } 777 778 /// Generates a store on a dense or sparse tensor. 779 static void genTensorStore(Merger &merger, CodeGen &codegen, 780 PatternRewriter &rewriter, linalg::GenericOp op, 781 Value rhs) { 782 Location loc = op.getLoc(); 783 // Test if this is a scalarized reduction. 784 if (codegen.redVal) { 785 if (codegen.curVecLength > 1) 786 rhs = rewriter.create<arith::SelectOp>(loc, codegen.curVecMask, rhs, 787 codegen.redVal); 788 updateReduc(merger, codegen, rhs); 789 return; 790 } 791 // Store during insertion. 792 OpOperand *t = op.getOutputOperand(0); 793 if (t == codegen.sparseOut) { 794 genInsertionStore(codegen, rewriter, op, t, rhs); 795 return; 796 } 797 // Actual store. 798 SmallVector<Value, 4> args; 799 Value ptr = genSubscript(codegen, rewriter, op, t, args); 800 if (codegen.curVecLength > 1) 801 genVectorStore(codegen, rewriter, rhs, ptr, args); 802 else 803 rewriter.create<memref::StoreOp>(loc, rhs, ptr, args); 804 } 805 806 /// Generates a pointer/index load from the sparse storage scheme. 
/// Narrower data types need to be zero extended before casting the value
/// into the index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vectorType(codegen, rewriter.getI32Type()), vload);
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vectorType(codegen, rewriter.getI64Type()), vload);
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), load);
    load =
        rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), load);
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv =
        rewriter.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul);
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<arith::AddIOp>(loc, mul, i);
}

/// Recursively generates tensor expression.
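/// For example (sketch), for x(i) = a(i) * b(i) + c(i) the recursion emits the
/// loads of a, b, and c at the current loop position and combines them with
/// the arithmetic built by merger.buildExp (e.g. arith.mulf and arith.addf for
/// floating-point element types).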
874 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 875 linalg::GenericOp op, unsigned exp) { 876 Location loc = op.getLoc(); 877 if (exp == -1u) 878 return Value(); 879 if (merger.exp(exp).kind == Kind::kTensor) 880 return genTensorLoad(merger, codegen, rewriter, op, exp); 881 if (merger.exp(exp).kind == Kind::kInvariant) 882 return genInvariantValue(merger, codegen, rewriter, exp); 883 Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0); 884 Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1); 885 return merger.buildExp(rewriter, loc, exp, v0, v1); 886 } 887 888 /// Determines if affine expression is invariant. 889 static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a, 890 unsigned ldx, bool &atLevel) { 891 switch (a.getKind()) { 892 case AffineExprKind::DimId: { 893 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 894 if (idx == ldx) 895 atLevel = true; 896 return codegen.loops[idx] != nullptr; // no longer in play? 897 } 898 case AffineExprKind::Add: 899 case AffineExprKind::Mul: { 900 auto binOp = a.cast<AffineBinaryOpExpr>(); 901 return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) && 902 isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel); 903 } 904 default: 905 return true; 906 } 907 } 908 909 /// Hoists loop invariant tensor loads for which indices have been exhausted. 910 static void genInvariants(Merger &merger, CodeGen &codegen, 911 PatternRewriter &rewriter, linalg::GenericOp op, 912 unsigned exp, unsigned ldx, bool atStart, 913 Kind last = Kind::kTensor) { 914 if (exp == -1u) 915 return; 916 if (merger.exp(exp).kind == Kind::kTensor) { 917 // Inspect tensor indices. 918 bool atLevel = ldx == -1u; 919 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 920 auto map = op.getTiedIndexingMap(t); 921 auto enc = getSparseTensorEncoding(t->get().getType()); 922 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 923 AffineExpr a = map.getResult(perm(enc, d)); 924 if (!isInvariantAffine(codegen, a, ldx, atLevel)) 925 return; // still in play 926 } 927 // All exhausted at this level (atLevel denotes exactly at this level). 928 if (!atLevel) 929 return; 930 OpOperand *lhs = op.getOutputOperand(0); 931 if (lhs == t) { 932 // Start or end a scalarized reduction 933 if (atStart) { 934 Value load = genTensorLoad(merger, codegen, rewriter, op, exp); 935 codegen.redKind = getReduction(last); 936 codegen.redExp = exp; 937 updateReduc(merger, codegen, load); 938 } else { 939 Value redVal = codegen.redVal; 940 updateReduc(merger, codegen, Value()); 941 codegen.redExp = -1u; 942 codegen.redKind = kNoReduc; 943 genTensorStore(merger, codegen, rewriter, op, redVal); 944 } 945 } else { 946 // Start or end loop invariant hoisting of a tensor load. 947 merger.exp(exp).val = 948 atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); 949 } 950 } else if (merger.exp(exp).kind != Kind::kInvariant) { 951 // Traverse into the binary operations. Note that we only hoist 952 // tensor loads, since subsequent MLIR/LLVM passes know how to 953 // deal with all other kinds of derived loop invariants. 
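    // For example (sketch), in x(i, j) = a(i, j) * b(i), the load b(i) only
    // depends on the outer index i, so it is hoisted out of the j-loop and
    // cached in merger.exp(...).val for the duration of that loop.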
954 Kind last = merger.exp(exp).kind; 955 unsigned e0 = merger.exp(exp).children.e0; 956 unsigned e1 = merger.exp(exp).children.e1; 957 genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last); 958 genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last); 959 } 960 } 961 962 /// Generates an expanded access pattern in innermost dimension. 963 static void genExpansion(Merger &merger, CodeGen &codegen, 964 PatternRewriter &rewriter, linalg::GenericOp op, 965 unsigned at, bool atStart) { 966 OpOperand *lhs = codegen.sparseOut; 967 if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 || 968 at != codegen.outerParNest) 969 return; // not needed at this level 970 // Generate start or end of an expanded access pattern. 971 Value tensor = lhs->get(); 972 Location loc = op.getLoc(); 973 if (atStart) { 974 auto dynShape = {ShapedType::kDynamicSize}; 975 Type etp = tensor.getType().cast<ShapedType>().getElementType(); 976 Type t1 = MemRefType::get(dynShape, etp); 977 Type t2 = MemRefType::get(dynShape, rewriter.getI1Type()); 978 Type t3 = MemRefType::get(dynShape, rewriter.getIndexType()); 979 Type t4 = rewriter.getIndexType(); 980 auto res = 981 rewriter.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor); 982 assert(res.getNumResults() == 4); 983 assert(!codegen.expValues); 984 codegen.expValues = res.getResult(0); 985 codegen.expFilled = res.getResult(1); 986 codegen.expAdded = res.getResult(2); 987 codegen.expCount = res.getResult(3); 988 } else { 989 assert(codegen.expValues); 990 rewriter.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues, 991 codegen.expFilled, codegen.expAdded, 992 codegen.expCount); 993 codegen.expValues = codegen.expFilled = codegen.expAdded = 994 codegen.expCount = Value(); 995 } 996 } 997 998 /// Generates initialization code for the subsequent loop sequence at 999 /// current index level. Returns true if the loop sequence needs to 1000 /// maintain the universal index. 1001 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 1002 linalg::GenericOp op, std::vector<unsigned> &topSort, 1003 unsigned at, BitVector &inits) { 1004 bool needsUniv = false; 1005 Location loc = op.getLoc(); 1006 unsigned idx = topSort[at]; 1007 1008 // Initialize sparse positions. 1009 for (unsigned b = 0, be = inits.size(); b < be; b++) { 1010 if (inits[b]) { 1011 unsigned tensor = merger.tensor(b); 1012 assert(idx == merger.index(b)); 1013 if (merger.isDim(b, Dim::kSparse)) { 1014 // Initialize sparse index. 1015 unsigned pat = at; 1016 for (; pat != 0; pat--) { 1017 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1018 break; 1019 } 1020 Value ptr = codegen.pointers[tensor][idx]; 1021 Value one = constantIndex(rewriter, loc, 1); 1022 Value p0 = (pat == 0) ? constantIndex(rewriter, loc, 0) 1023 : codegen.pidxs[tensor][topSort[pat - 1]]; 1024 codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); 1025 Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one); 1026 codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1); 1027 } else { 1028 // Dense index still in play. 1029 needsUniv = true; 1030 } 1031 } 1032 } 1033 1034 // Initialize the universal dense index. 1035 codegen.loops[idx] = constantIndex(rewriter, loc, 0); 1036 return needsUniv; 1037 } 1038 1039 /// Returns vectorization strategy. Any implicit inner loop in the Linalg 1040 /// operation is a candidate. Whether it is actually converted to SIMD code 1041 /// depends on the requested strategy. 
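/// For example (sketch), under kDenseInnerLoop only an innermost loop over
/// dense storage (e.g. the dense i-loop of x(i) = a(i) + b(i)) is vectorized,
/// whereas kAnyStorageInnerLoop also accepts an innermost loop that iterates
/// over compressed storage.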
1042 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) { 1043 switch (codegen.options.vectorizationStrategy) { 1044 case SparseVectorizationStrategy::kNone: 1045 return false; 1046 case SparseVectorizationStrategy::kDenseInnerLoop: 1047 return isInner && !isSparse; 1048 case SparseVectorizationStrategy::kAnyStorageInnerLoop: 1049 return isInner; 1050 } 1051 llvm_unreachable("unexpected vectorization strategy"); 1052 } 1053 1054 /// Returns parallelization strategy. Any implicit loop in the Linalg operation 1055 /// that is marked "parallel" is a candidate. Whether it is actually converted 1056 /// to a parallel operation depends on the requested strategy. 1057 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, 1058 bool isSparse, bool isVector) { 1059 switch (codegen.options.parallelizationStrategy) { 1060 case SparseParallelizationStrategy::kNone: 1061 return false; 1062 case SparseParallelizationStrategy::kDenseOuterLoop: 1063 return isOuter && !isSparse && !isReduction && !isVector; 1064 case SparseParallelizationStrategy::kAnyStorageOuterLoop: 1065 return isOuter && !isReduction && !isVector; 1066 case SparseParallelizationStrategy::kDenseAnyLoop: 1067 return !isSparse && !isReduction && !isVector; 1068 case SparseParallelizationStrategy::kAnyStorageAnyLoop: 1069 return !isReduction && !isVector; 1070 } 1071 llvm_unreachable("unexpected parallelization strategy"); 1072 } 1073 1074 /// Checks unit stride for dense tensors. The iteration graph may have ignored 1075 /// dense access patterns in order to avoid cycles (sparse access patterns are 1076 /// always placed innermost), but that means dense access has become strided. 1077 /// This prevents effective vectorization. 1078 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 1079 unsigned idx) { 1080 for (OpOperand *t : op.getInputAndOutputOperands()) { 1081 if (!getSparseTensorEncoding(t->get().getType())) { 1082 auto map = op.getTiedIndexingMap(t); 1083 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1084 AffineExpr a = map.getResult(d); 1085 // Report non-unit stride if innermost index appears at an outer 1086 // dimension (true non-unit stride) or if the innermost index appears 1087 // in a compound subscript in the innermost dimension. Even if the 1088 // latter is unit stride, it does not play well with scatter/gather. 1089 // TODO: accept unit stride affine innermost like a[i,j+k+1]? 1090 if (a.isFunctionOfDim(idx) && 1091 ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) 1092 return false; 1093 } 1094 } 1095 } 1096 return true; 1097 } 1098 1099 /// Generates a for-loop on a single index. 1100 static Operation *genFor(Merger &merger, CodeGen &codegen, 1101 PatternRewriter &rewriter, linalg::GenericOp op, 1102 bool isOuter, bool isInner, unsigned idx, 1103 BitVector &indices) { 1104 unsigned fb = indices.find_first(); 1105 unsigned tensor = merger.tensor(fb); 1106 assert(idx == merger.index(fb)); 1107 auto iteratorTypes = op.iterator_types().getValue(); 1108 bool isReduction = isReductionIterator(iteratorTypes[idx]); 1109 bool isSparse = merger.isDim(fb, Dim::kSparse); 1110 bool isVector = !codegen.sparseOut && 1111 isVectorFor(codegen, isInner, isSparse) && 1112 denseUnitStrides(merger, op, idx); 1113 bool isParallel = 1114 !codegen.sparseOut && 1115 isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 1116 1117 // Prepare vector length. 
1118 if (isVector) 1119 codegen.curVecLength = codegen.options.vectorLength; 1120 1121 // Loop bounds and increment. 1122 Location loc = op.getLoc(); 1123 Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 1124 Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 1125 Value step = constantIndex(rewriter, loc, codegen.curVecLength); 1126 1127 // Emit a parallel loop. 1128 if (isParallel) { 1129 assert(!isVector); 1130 scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step); 1131 if (isSparse) 1132 codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 1133 else 1134 codegen.loops[idx] = parOp.getInductionVars()[0]; 1135 rewriter.setInsertionPointToStart(parOp.getBody()); 1136 return parOp; 1137 } 1138 1139 // Emit a sequential or vector loop. 1140 SmallVector<Value, 4> operands; 1141 if (codegen.redVal) { 1142 // In a vector loop, bring reduction into SIMD form, if not already. 1143 if (isVector && !codegen.redVal.getType().isa<VectorType>()) { 1144 VectorType vtp = vectorType(codegen, codegen.redVal.getType()); 1145 Value vred = genVectorReducInit(codegen, rewriter, loc, vtp); 1146 updateReduc(merger, codegen, vred); 1147 } 1148 operands.push_back(codegen.redVal); 1149 } 1150 if (codegen.expValues) 1151 operands.push_back(codegen.expCount); 1152 scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands); 1153 if (codegen.redVal) 1154 updateReduc(merger, codegen, forOp.getRegionIterArgs().front()); 1155 if (codegen.expValues) 1156 codegen.expCount = forOp.getRegionIterArgs().back(); 1157 // Assign induction variable to sparse or dense index. 1158 Value iv = forOp.getInductionVar(); 1159 if (isSparse) 1160 codegen.pidxs[tensor][idx] = iv; 1161 else 1162 codegen.loops[idx] = iv; 1163 rewriter.setInsertionPointToStart(forOp.getBody()); 1164 // Share vector iteration mask between all subsequent loads/stores. 1165 if (isVector) 1166 codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step); 1167 return forOp; 1168 } 1169 1170 /// Emit a while-loop for co-iteration over multiple indices. 1171 static Operation *genWhile(Merger &merger, CodeGen &codegen, 1172 PatternRewriter &rewriter, linalg::GenericOp op, 1173 unsigned idx, bool needsUniv, 1174 BitVector &indices) { 1175 SmallVector<Type, 4> types; 1176 SmallVector<Value, 4> operands; 1177 // Construct the while-loop with a parameter for each index. 
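  // E.g. (sketch) co-iterating two sparse operands yields roughly:
  //   %r:2 = scf.while (%pa = %pa0, %pb = %pb0) : (index, index) -> ... {
  //     <"before" region: conjunction of %pa < %ha and %pb < %hb>
  //   } do {
  //     <"after" region: loop body and induction>
  //   }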
1178 Type indexType = rewriter.getIndexType(); 1179 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1180 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1181 unsigned tensor = merger.tensor(b); 1182 assert(idx == merger.index(b)); 1183 types.push_back(indexType); 1184 operands.push_back(codegen.pidxs[tensor][idx]); 1185 } 1186 } 1187 if (codegen.redVal) { 1188 types.push_back(codegen.redVal.getType()); 1189 operands.push_back(codegen.redVal); 1190 } 1191 if (codegen.expValues) { 1192 types.push_back(indexType); 1193 operands.push_back(codegen.expCount); 1194 } 1195 if (needsUniv) { 1196 types.push_back(indexType); 1197 operands.push_back(codegen.loops[idx]); 1198 } 1199 assert(types.size() == operands.size()); 1200 Location loc = op.getLoc(); 1201 scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); 1202 1203 SmallVector<Location> locs(types.size(), loc); 1204 Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, locs); 1205 Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, types, locs); 1206 1207 // Build the "before" region, which effectively consists 1208 // of a conjunction of "i < upper" tests on all induction. 1209 rewriter.setInsertionPointToStart(&whileOp.getBefore().front()); 1210 Value cond; 1211 unsigned o = 0; 1212 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1213 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1214 unsigned tensor = merger.tensor(b); 1215 assert(idx == merger.index(b)); 1216 Value op1 = before->getArgument(o); 1217 Value op2 = codegen.highs[tensor][idx]; 1218 Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 1219 op1, op2); 1220 cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc; 1221 codegen.pidxs[tensor][idx] = after->getArgument(o++); 1222 } 1223 } 1224 if (codegen.redVal) 1225 updateReduc(merger, codegen, after->getArgument(o++)); 1226 if (codegen.expValues) 1227 codegen.expCount = after->getArgument(o++); 1228 if (needsUniv) 1229 codegen.loops[idx] = after->getArgument(o++); 1230 assert(o == operands.size()); 1231 rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); 1232 rewriter.setInsertionPointToStart(&whileOp.getAfter().front()); 1233 return whileOp; 1234 } 1235 1236 /// Generates a for-loop or a while-loop, depending on whether it implements 1237 /// singleton iteration or co-iteration over the given conjunction. 1238 static Operation *genLoop(Merger &merger, CodeGen &codegen, 1239 PatternRewriter &rewriter, linalg::GenericOp op, 1240 std::vector<unsigned> &topSort, unsigned at, 1241 bool needsUniv, BitVector &indices) { 1242 unsigned idx = topSort[at]; 1243 if (indices.count() == 1) { 1244 bool isOuter = at == 0; 1245 bool isInner = at == topSort.size() - 1; 1246 return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, 1247 indices); 1248 } 1249 return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); 1250 } 1251 1252 /// Generates the local variables for this loop, consisting of the sparse 1253 /// indices, restored universal dense index, and dense positions. 1254 static void genLocals(Merger &merger, CodeGen &codegen, 1255 PatternRewriter &rewriter, linalg::GenericOp op, 1256 std::vector<unsigned> &topSort, unsigned at, 1257 bool needsUniv, BitVector &locals) { 1258 Location loc = op.getLoc(); 1259 unsigned idx = topSort[at]; 1260 1261 // Initialize sparse indices. 
1262 Value min; 1263 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1264 if (locals[b] && merger.isDim(b, Dim::kSparse)) { 1265 unsigned tensor = merger.tensor(b); 1266 assert(idx == merger.index(b)); 1267 Value ptr = codegen.indices[tensor][idx]; 1268 Value s = codegen.pidxs[tensor][idx]; 1269 Value load = genLoad(codegen, rewriter, loc, ptr, s); 1270 codegen.idxs[tensor][idx] = load; 1271 if (!needsUniv) { 1272 if (min) { 1273 Value cmp = rewriter.create<arith::CmpIOp>( 1274 loc, arith::CmpIPredicate::ult, load, min); 1275 min = rewriter.create<arith::SelectOp>(loc, cmp, load, min); 1276 } else { 1277 min = load; 1278 } 1279 } 1280 } 1281 } 1282 1283 // Merge dense universal index over minimum. 1284 if (min) { 1285 assert(!needsUniv); 1286 codegen.loops[idx] = min; 1287 } 1288 1289 // Initialize dense positions. Note that we generate dense indices of the 1290 // output tensor unconditionally, since they may not appear in the lattice, 1291 // but may be needed for linearized codegen. 1292 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1293 if ((locals[b] || merger.isOutTensor(b, idx)) && 1294 merger.isDim(b, Dim::kDense)) { 1295 unsigned tensor = merger.tensor(b); 1296 assert(idx == merger.index(b)); 1297 unsigned pat = at; 1298 for (; pat != 0; pat--) 1299 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1300 break; 1301 Value p = (pat == 0) ? constantIndex(rewriter, loc, 0) 1302 : codegen.pidxs[tensor][topSort[pat - 1]]; 1303 codegen.pidxs[tensor][idx] = genAddress( 1304 codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); 1305 } 1306 } 1307 1308 // Move the insertion indices in lexicographic index order. During access 1309 // pattern expansion, we can skip setting the innermost dimension. 1310 if (codegen.sparseOut && !codegen.expValues) { 1311 Value pos = constantIndex(rewriter, loc, at); 1312 rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx, 1313 pos); 1314 } 1315 } 1316 1317 /// Generates the induction structure for a while-loop. 1318 static void genWhileInduction(Merger &merger, CodeGen &codegen, 1319 PatternRewriter &rewriter, linalg::GenericOp op, 1320 unsigned idx, bool needsUniv, 1321 BitVector &induction, 1322 scf::WhileOp whileOp) { 1323 Location loc = op.getLoc(); 1324 // Finalize each else branch of all if statements. 1325 if (codegen.redVal || codegen.expValues) { 1326 while (auto ifOp = dyn_cast_or_null<scf::IfOp>( 1327 rewriter.getInsertionBlock()->getParentOp())) { 1328 unsigned y = 0; 1329 SmallVector<Value, 4> yields; 1330 if (codegen.redVal) { 1331 yields.push_back(codegen.redVal); 1332 updateReduc(merger, codegen, ifOp.getResult(y++)); 1333 } 1334 if (codegen.expValues) { 1335 yields.push_back(codegen.expCount); 1336 codegen.expCount = ifOp->getResult(y++); 1337 } 1338 assert(y == yields.size()); 1339 rewriter.create<scf::YieldOp>(loc, yields); 1340 rewriter.setInsertionPointAfter(ifOp); 1341 } 1342 } 1343 rewriter.setInsertionPointToEnd(&whileOp.getAfter().front()); 1344 // Finalize the induction. Note that the induction could be performed 1345 // in the individual if-branches to avoid re-evaluating the conditions. 1346 // However, that would result in a rather elaborate forest of yield 1347 // instructions during code generation. Moreover, performing the induction 1348 // after the if-statements more closely resembles code generated by TACO. 
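  // E.g. (sketch) for each sparse tensor t participating in the co-iteration:
  //   %hit = arith.cmpi eq, %idx_t, %univ   // did t contribute this index?
  //   %inc = arith.addi %pidx_t, %c1
  //   %new = arith.select %hit, %inc, %pidx_t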
1349 unsigned o = 0; 1350 SmallVector<Value, 4> operands; 1351 Value one = constantIndex(rewriter, loc, 1); 1352 for (unsigned b = 0, be = induction.size(); b < be; b++) { 1353 if (induction[b] && merger.isDim(b, Dim::kSparse)) { 1354 unsigned tensor = merger.tensor(b); 1355 assert(idx == merger.index(b)); 1356 Value op1 = codegen.idxs[tensor][idx]; 1357 Value op2 = codegen.loops[idx]; 1358 Value op3 = codegen.pidxs[tensor][idx]; 1359 Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 1360 op1, op2); 1361 Value add = rewriter.create<arith::AddIOp>(loc, op3, one); 1362 operands.push_back(rewriter.create<arith::SelectOp>(loc, cmp, add, op3)); 1363 codegen.pidxs[tensor][idx] = whileOp->getResult(o++); 1364 } 1365 } 1366 if (codegen.redVal) { 1367 operands.push_back(codegen.redVal); 1368 updateReduc(merger, codegen, whileOp->getResult(o++)); 1369 } 1370 if (codegen.expValues) { 1371 operands.push_back(codegen.expCount); 1372 codegen.expCount = whileOp->getResult(o++); 1373 } 1374 if (needsUniv) { 1375 operands.push_back( 1376 rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one)); 1377 codegen.loops[idx] = whileOp->getResult(o++); 1378 } 1379 assert(o == operands.size()); 1380 rewriter.create<scf::YieldOp>(loc, operands); 1381 rewriter.setInsertionPointAfter(whileOp); 1382 } 1383 1384 /// Generates the induction structure for a for-loop. 1385 static void genForInduction(Merger &merger, CodeGen &codegen, 1386 PatternRewriter &rewriter, linalg::GenericOp op, 1387 Operation *loop) { 1388 Location loc = op.getLoc(); 1389 unsigned o = 0; 1390 SmallVector<Value, 4> operands; 1391 if (codegen.redVal) { 1392 operands.push_back(codegen.redVal); 1393 updateReduc(merger, codegen, loop->getResult(o++)); 1394 } 1395 if (codegen.expValues) { 1396 operands.push_back(codegen.expCount); 1397 codegen.expCount = loop->getResult(o++); 1398 } 1399 assert(o == operands.size()); 1400 if (o > 0) 1401 rewriter.create<scf::YieldOp>(loc, operands); 1402 rewriter.setInsertionPointAfter(loop); 1403 } 1404 1405 /// Generates a single if-statement within a while-loop. 1406 static scf::IfOp genIf(Merger &merger, CodeGen &codegen, 1407 PatternRewriter &rewriter, linalg::GenericOp op, 1408 unsigned idx, BitVector &conditions) { 1409 Location loc = op.getLoc(); 1410 SmallVector<Type, 4> types; 1411 Value cond; 1412 for (unsigned b = 0, be = conditions.size(); b < be; b++) { 1413 if (conditions[b]) { 1414 unsigned tensor = merger.tensor(b); 1415 assert(idx == merger.index(b)); 1416 Value clause; 1417 if (merger.isDim(b, Dim::kSparse)) { 1418 Value op1 = codegen.idxs[tensor][idx]; 1419 Value op2 = codegen.loops[idx]; 1420 clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 1421 op1, op2); 1422 } else { 1423 clause = constantI1(rewriter, loc, true); 1424 } 1425 cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause; 1426 } 1427 } 1428 if (codegen.redVal) 1429 types.push_back(codegen.redVal.getType()); 1430 if (codegen.expValues) 1431 types.push_back(rewriter.getIndexType()); 1432 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true); 1433 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); 1434 return ifOp; 1435 } 1436 1437 /// Generates end of true branch of if-statement within a while-loop. 
static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                  linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
                  Value redInput, Value cntInput) {
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, redInput);
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = cntInput;
  }
  if (!operands.empty())
    rewriter.create<scf::YieldOp>(op.getLoc(), operands);
  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         std::vector<unsigned> &topSort, unsigned exp,
                         unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  assert(!codegen.loops[idx]);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
  // Emit access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  bool needsUniv =
      genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  if (needsUniv) {
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}

/// Starts a single loop in current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in current sequence. Returns the new value for
/// needsUniv.
static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp);
    return needsUniv;
  }
  // End a for-loop.
  genForInduction(merger, codegen, rewriter, op, loop);
  return false;
}

/// Ends a loop sequence at given level.
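/// Clears the loop bookkeeping for this index and, when the pending
/// reduction was kept in SIMD form, folds the vector of partial results
/// back into a scalar (conceptually a horizontal vector.reduction, e.g.
/// collapsing a vector<16xf32> accumulator into a single f32; shown here
/// only as a sketch of the intent).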
static void endLoopSeq(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned exp, unsigned at, unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  codegen.loops[idx] = Value();
  // Bring a pending reduction back from SIMD form when sequence ends.
  if (codegen.redVal)
    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
      updateReduc(merger, codegen,
                  genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
  // Finalize access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/false);
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for coiteration.
    Value redInput = codegen.redVal;
    Value cntInput = codegen.expCount;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
}

/// Converts the result computed by the sparse kernel into the required form.
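/// For an annotated (sparse) output, the tensor is rematerialized from its
/// underlying sparse storage scheme (flagged when the kernel inserted
/// entries into that output); for a dense output, the bufferized values
/// array is converted back with bufferization.to_tensor. A rough sketch of
/// the replacement IR (types and names illustrative only):
///   %0 = sparse_tensor.load %t hasInserts : tensor<?x?xf64, #SparseMatrix>
///   %1 = bufferization.to_tensor %buf : memref<?x?xf64>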
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format.
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
                                        codegen.sparseOut == lhs);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = codegen.buffers.back(); // value array
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
  }
}

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
    if (!optExp.hasValue())
      return failure();
    unsigned exp = optExp.getValue();

    // Rejects an inadmissable tensor expression.
    OpOperand *sparseOut = nullptr;
    unsigned outerParNest = 0;
    if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
                               outerParNest))
      return failure();

    // Recursively generates code.
    merger.setHasSparseOut(sparseOut != nullptr);
    CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
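// Typical usage of this entry point (sketch only; the actual pass driver
// lives elsewhere in the SparseTensor transforms) is to populate a pattern
// set and apply it greedily to the enclosing operation:
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));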