1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements converting sparse tensor types to actual sparse code. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CodegenUtils.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 18 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 19 #include "mlir/Dialect/Func/IR/FuncOps.h" 20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 21 #include "mlir/Dialect/Linalg/IR/Linalg.h" 22 #include "mlir/Dialect/Linalg/Utils/Utils.h" 23 #include "mlir/Dialect/MemRef/IR/MemRef.h" 24 #include "mlir/Dialect/SCF/IR/SCF.h" 25 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 26 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 28 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 29 #include "mlir/Dialect/Vector/IR/VectorOps.h" 30 #include "mlir/IR/Matchers.h" 31 #include "mlir/IR/TensorEncoding.h" 32 #include "llvm/ADT/SmallBitVector.h" 33 34 using namespace mlir; 35 using namespace mlir::sparse_tensor; 36 37 //===----------------------------------------------------------------------===// 38 // Declarations of data structures. 39 //===----------------------------------------------------------------------===// 40 41 namespace { 42 43 // Iteration graph sorting. 44 enum SortMask { 45 kSparseOnly = 0x0, 46 kIncludeDense = 0x1, 47 kIncludeUndef = 0x2, 48 kIncludeAll = 0x3 49 }; 50 51 // Reduction kinds. 52 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor }; 53 54 // Code generation. 55 struct CodeGen { 56 CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops, 57 OpOperand *op, unsigned nest) 58 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 59 pointers(numTensors, std::vector<Value>(numLoops)), 60 indices(numTensors, std::vector<Value>(numLoops)), 61 highs(numTensors, std::vector<Value>(numLoops)), 62 pidxs(numTensors, std::vector<Value>(numLoops)), 63 idxs(numTensors, std::vector<Value>(numLoops)), redVal(), sparseOut(op), 64 outerParNest(nest), lexIdx(), lexVal(), expValues(), expFilled(), 65 expAdded(), expCount(), curVecMask() {} 66 /// Sparsification options. 67 SparsificationOptions options; 68 /// Universal dense indices and upper bounds (by index). The loops array 69 /// is updated with the value of the universal dense index in the current 70 /// loop. The sizes array is set once with the inferred dimension sizes. 71 std::vector<Value> loops; 72 std::vector<Value> sizes; 73 /// Buffers for storing dense and sparse numerical values (by tensor). 74 /// This array is set once during bufferization of all tensors. 75 std::vector<Value> buffers; 76 /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 77 /// This array is set once during bufferization of all sparse tensors. 78 std::vector<std::vector<Value>> pointers; 79 std::vector<std::vector<Value>> indices; 80 /// Sparse iteration information (by tensor and index). These arrays 81 /// are updated to remain current within the current loop. 
82 std::vector<std::vector<Value>> highs; 83 std::vector<std::vector<Value>> pidxs; 84 std::vector<std::vector<Value>> idxs; 85 /// Current reduction, updated during code generation. When indices of a 86 /// reduction are exhausted, all inner loops can use a scalarized reduction. 87 unsigned redExp = -1u; 88 Value redVal; 89 Reduction redKind = kNoReduc; 90 // Sparse tensor as output. Implemented either through direct injective 91 // insertion in lexicographic index order (where indices are updated 92 // in the temporary array `lexIdx`) or through access pattern expansion 93 // in the innermost loop nest (`expValues` through `expCount`). 94 OpOperand *sparseOut; 95 unsigned outerParNest; 96 Value lexIdx; 97 Value lexVal; 98 Value expValues; 99 Value expFilled; 100 Value expAdded; 101 Value expCount; 102 // Current vector length and mask. 103 unsigned curVecLength = 1; 104 Value curVecMask; 105 }; 106 107 } // namespace 108 109 //===----------------------------------------------------------------------===// 110 // Sparse compiler analysis methods. 111 //===----------------------------------------------------------------------===// 112 113 /// Helper method to construct a permuted dimension ordering 114 /// that adheres to the given topological sort. 115 static AffineMap permute(MLIRContext *context, AffineMap m, 116 std::vector<unsigned> &topSort) { 117 unsigned sz = topSort.size(); 118 SmallVector<unsigned, 4> perm(sz); 119 for (unsigned i = 0; i < sz; i++) 120 perm[i] = m.getPermutedPosition(topSort[i]); 121 return AffineMap::getPermutationMap(perm, context); 122 } 123 124 /// Helper method to apply dimension ordering permutation. 125 static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) { 126 if (enc) { 127 auto order = enc.getDimOrdering(); 128 if (order) { 129 assert(order.isPermutation()); 130 return order.getDimPosition(d); 131 } 132 } 133 return d; 134 } 135 136 /// Helper method to translate dim level type to internal representation. 137 static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) { 138 if (enc) { 139 SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 140 if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 141 return Dim::kSparse; 142 if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 143 return Dim::kSingle; 144 } 145 return Dim::kDense; 146 } 147 148 /// Helper method to inspect affine expressions. Rejects cases where the 149 /// same index is used more than once. Also rejects affine expressions 150 /// that are not a direct index for annotated tensors. 151 // TODO: accept more affine cases for sparse tensors 152 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim, 153 bool isDense) { 154 switch (a.getKind()) { 155 case AffineExprKind::DimId: { 156 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 157 if (!merger.isDim(tensor, idx, Dim::kUndef)) 158 return false; // used more than once 159 merger.setDim(tensor, idx, dim); 160 return true; 161 } 162 case AffineExprKind::Add: 163 case AffineExprKind::Mul: { 164 if (!isDense) 165 return false; 166 auto binOp = a.cast<AffineBinaryOpExpr>(); 167 return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) && 168 findAffine(merger, tensor, binOp.getRHS(), dim, isDense); 169 } 170 case AffineExprKind::Constant: 171 return isDense; 172 default: 173 return false; 174 } 175 } 176 177 /// Helper method to inspect sparse encodings in the tensor types. 178 /// Fills the per-dimension sparsity information for all tensors. 
179 /// Returns true if the sparse annotations and affine subscript 180 /// expressions of all tensors are admissable. Returns false if 181 /// no annotations are found or inadmissable constructs occur. 182 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 183 bool annotated = false; 184 for (OpOperand *t : op.getInputAndOutputOperands()) { 185 auto map = op.getTiedIndexingMap(t); 186 auto enc = getSparseTensorEncoding(t->get().getType()); 187 if (enc) 188 annotated = true; 189 assert(map.getNumResults() == op.getRank(t)); 190 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 191 unsigned tensor = t->getOperandNumber(); 192 AffineExpr a = map.getResult(perm(enc, d)); 193 if (!findAffine(merger, tensor, a, toDim(enc, d), !enc)) 194 return false; // inadmissable affine expression 195 } 196 } 197 return annotated; 198 } 199 200 /// A DFS helper to compute a topological sort. Note that recursion is 201 /// bounded by the number of implicit loops, which is always small. 202 /// Returns false when a cycle is detected. 203 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 204 std::vector<unsigned> &topSort, 205 std::vector<std::vector<bool>> &adjM) { 206 if (visit[i] != 0) 207 return visit[i] != 1; // 1 denotes cycle! 208 visit[i] = 1; 209 for (unsigned j = 0, e = visit.size(); j < e; j++) 210 if (adjM[i][j]) 211 if (!topSortDFS(j, visit, topSort, adjM)) 212 return false; 213 visit[i] = 2; 214 topSort.push_back(i); 215 return true; 216 } 217 218 /// Helper method to add all constraints from the indices in one affine 219 /// expression before all indices in the other affine expression. For 220 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3. 221 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM, 222 AffineExpr a, AffineExpr b, unsigned fidx) { 223 switch (a.getKind()) { 224 case AffineExprKind::DimId: { 225 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 226 if (b) 227 addAffineOrderings(adjM, b, AffineExpr(), idx); 228 else 229 adjM[fidx][idx] = true; 230 break; 231 } 232 case AffineExprKind::Add: 233 case AffineExprKind::Mul: { 234 auto binOp = a.cast<AffineBinaryOpExpr>(); 235 addAffineOrderings(adjM, binOp.getLHS(), b, fidx); 236 addAffineOrderings(adjM, binOp.getRHS(), b, fidx); 237 break; 238 } 239 default: 240 break; 241 } 242 } 243 244 /// Computes a topologically sorted iteration graph for the linalg operation. 245 /// Ensures all tensors are visited in natural index order. This is essential 246 /// for sparse storage formats since these only support access along fixed 247 /// dimensions. Even for dense storage formats, however, the natural index 248 /// order yields innermost unit-stride access with better spatial locality. 249 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, 250 std::vector<unsigned> &topSort, unsigned mask, 251 OpOperand *skip = nullptr) { 252 // Set up an n x n from/to adjacency matrix of the iteration graph 253 // for the implicit loop indices i_0 .. i_n-1. 254 unsigned n = op.getNumLoops(); 255 std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false)); 256 257 // Iterate over the indexing maps of every tensor in the tensor expression. 258 for (OpOperand *t : op.getInputAndOutputOperands()) { 259 // Skip tensor during cycle resolution. 260 if (t == skip) 261 continue; 262 // Get map and encoding. 
263 auto map = op.getTiedIndexingMap(t); 264 auto enc = getSparseTensorEncoding(t->get().getType()); 265 assert(map.getNumDims() == n); 266 // Skip dense tensor constraints when not requested. 267 if (!(mask & SortMask::kIncludeDense) && !enc) 268 continue; 269 // Each tensor expression and optional dimension ordering (row-major 270 // by default) puts an ordering constraint on the loop indices. For 271 // example, the tensor expresion A_ijk forces the ordering i < j < k 272 // on the loop indices if no explicit dimension ordering is given. 273 for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) { 274 AffineExpr f = map.getResult(perm(enc, d - 1)); 275 AffineExpr t = map.getResult(perm(enc, d)); 276 addAffineOrderings(adjM, f, t, 0); 277 } 278 // Push unrelated loops into sparse iteration space, so these 279 // will be skipped more often. 280 if (mask & SortMask::kIncludeUndef) { 281 unsigned tensor = t->getOperandNumber(); 282 for (unsigned i = 0; i < n; i++) 283 if (merger.isDim(tensor, i, Dim::kSparse)) 284 for (unsigned j = 0; j < n; j++) 285 if (merger.isDim(tensor, j, Dim::kUndef)) 286 adjM[i][j] = true; 287 } 288 } 289 290 // Topologically sort the iteration graph to determine loop order. 291 // Report failure for a cyclic iteration graph. 292 topSort.clear(); 293 topSort.reserve(n); 294 std::vector<unsigned> visit(n, 0); 295 for (unsigned i = 0; i < n; i++) 296 if (visit[i] == 0) 297 if (!topSortDFS(i, visit, topSort, adjM)) 298 return false; // cycle! 299 std::reverse(std::begin(topSort), std::end(topSort)); 300 return true; 301 } 302 303 /// Returns true if tensor has an in-place annotation. 304 static bool isInPlace(Value val) { 305 if (auto arg = val.dyn_cast<BlockArgument>()) 306 if (auto funcOp = dyn_cast<func::FuncOp>(arg.getOwner()->getParentOp())) 307 if (auto attr = funcOp.getArgAttrOfType<BoolAttr>( 308 arg.getArgNumber(), 309 bufferization::BufferizableOpInterface::kInplaceableAttrName)) 310 return attr.getValue(); 311 return false; 312 } 313 314 /// Returns true if tensor materializes uninitialized into the computation. 315 static bool isMaterializing(Value val) { 316 return val.getDefiningOp<linalg::InitTensorOp>() || 317 val.getDefiningOp<bufferization::AllocTensorOp>(); 318 } 319 320 /// Returns true when the tensor expression is admissable for codegen. 321 /// Since all sparse input tensors are admissable, we just need to check 322 /// whether the out tensor in the tensor expression codegen is admissable. 323 /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective 324 /// nesting depth when a "truly dynamic" sparse tensor output occurs. 325 static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op, 326 std::vector<unsigned> &topSort, unsigned exp, 327 OpOperand **sparseOut, 328 unsigned &outerParNest) { 329 OpOperand *lhs = op.getOutputOperand(0); 330 unsigned tensor = lhs->getOperandNumber(); 331 auto enc = getSparseTensorEncoding(lhs->get().getType()); 332 // An non-annotated output tensor is assumed dense, and becomes a random 333 // access n-dim memref. Admissable since insertions cannot occur. 334 if (!enc) 335 return true; 336 // An all-dense annotated "sparse" output tensor becomes a linearized random 337 // access 1-dim memref. Also admissable since insertions cannot occur. 
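  // Illustrative example (hypothetical annotation, not taken from this file):
  // an output declared with an all-dense encoding such as
  //   #DenseMatrix = #sparse_tensor.encoding<{
  //     dimLevelType = [ "dense", "dense" ]
  //   }>
  // has no compressed dimensions, so the scan below leaves allDense == true
  // and the kernel is accepted without any insertion machinery.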
338 bool allDense = true; 339 auto iteratorTypes = op.iterator_types().getValue(); 340 unsigned numLoops = iteratorTypes.size(); 341 for (unsigned i = 0; i < numLoops; i++) 342 if (merger.isDim(tensor, i, Dim::kSparse)) { 343 allDense = false; 344 break; 345 } 346 if (allDense) 347 return true; 348 // A tensor expression with a sparse output tensor that changes its values 349 // but not its nonzero structure, an operation called "simply dynamic" in 350 // [Bik96,Ch9], is also admissable without special codegen, provided 351 // the tensor's underlying sparse storage scheme can be modified in-place. 352 if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get())) 353 return true; 354 // Accept "truly dynamic" if the output tensor materializes uninitialized 355 // into the computation and insertions occur in lexicographic index order. 356 if (isMaterializing(lhs->get())) { 357 unsigned nest = 0; 358 for (unsigned i = 0; i < numLoops; i++) { 359 if (isReductionIterator(iteratorTypes[topSort[i]])) 360 break; // terminate at first reduction 361 nest++; 362 } 363 // Determine admissable dynamic insertion situations: 364 // (1) fully injective, since there are no reductions, 365 // (2) admissable 1-d expansion in innermost dimension. 366 if (nest >= op.getRank(lhs) - 1) { 367 *sparseOut = lhs; 368 outerParNest = nest; 369 return true; 370 } 371 } 372 return false; 373 } 374 375 //===----------------------------------------------------------------------===// 376 // Sparse compiler synthesis methods (reductions). 377 //===----------------------------------------------------------------------===// 378 379 /// Maps reduction kind to vector::CombiningKind. 380 static vector::CombiningKind getCombiningKind(Reduction kind) { 381 switch (kind) { 382 case kNoReduc: 383 break; 384 case kSum: 385 return vector::CombiningKind::ADD; 386 case kProduct: 387 return vector::CombiningKind::MUL; 388 case kAnd: 389 return vector::CombiningKind::AND; 390 case kOr: 391 return vector::CombiningKind::OR; 392 case kXor: 393 return vector::CombiningKind::XOR; 394 } 395 llvm_unreachable("unknown reduction kind"); 396 } 397 398 /// Maps operation to reduction. 399 static Reduction getReduction(Kind kind) { 400 switch (kind) { 401 case Kind::kAddF: 402 case Kind::kAddC: 403 case Kind::kAddI: 404 case Kind::kSubF: 405 case Kind::kSubC: 406 case Kind::kSubI: 407 return kSum; 408 case Kind::kMulF: 409 case Kind::kMulC: 410 case Kind::kMulI: 411 return kProduct; 412 case Kind::kAndI: 413 return kAnd; 414 case Kind::kOrI: 415 return kOr; 416 case Kind::kXorI: 417 return kXor; 418 default: 419 llvm_unreachable("unexpected reduction operator"); 420 } 421 } 422 423 /// Generates an initial value for a vector reduction, following the scheme 424 /// given in Chapter 5 of "The Software Vectorization Handbook", where the 425 /// initial scalar value is correctly embedded in the vector reduction value, 426 /// and a straightforward horizontal reduction will complete the operation. 427 static Value genVectorReducInit(CodeGen &codegen, OpBuilder &builder, 428 Location loc, VectorType vtp) { 429 Value r = codegen.redVal; 430 switch (codegen.redKind) { 431 case kNoReduc: 432 break; 433 case kSum: 434 case kXor: 435 // Initialize reduction vector to: | 0 | .. | 0 | r | 436 return builder.create<vector::InsertElementOp>( 437 loc, r, constantZero(builder, loc, vtp), 438 constantIndex(builder, loc, 0)); 439 case kProduct: 440 // Initialize reduction vector to: | 1 | .. 
    //   | 1 | r |
    return builder.create<vector::InsertElementOp>(
        loc, r, constantOne(builder, loc, vtp), constantIndex(builder, loc, 0));
  case kAnd:
  case kOr:
    // Initialize reduction vector to: | r | .. | r | r |
    return builder.create<vector::BroadcastOp>(loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
}

/// Generates final value for a vector reduction.
static Value genVectorReducEnd(CodeGen &codegen, OpBuilder &builder,
                               Location loc, VectorType vtp) {
  vector::CombiningKind kind = getCombiningKind(codegen.redKind);
  return builder.create<vector::ReductionOp>(loc, kind, codegen.redVal);
}

/// Updates scalarized reduction value.
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, OpBuilder &builder,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Value tensor = lhs->get();
  bool isInit = op.isInitTensor(lhs);
  // An output tensor that is in-place can simply materialize from the buffer
  // of the tensor that appears in the outs() clause. For updates, this has
  // the advantage that only the nonzero values are involved in the
  // computation, keeping the operation O(nnz). In all other cases, we are
  // forced to zero out the buffer to enforce the assumption above, which may
  // negatively impact running complexity (viz. O(n^2 + nnz) vs. O(nnz) for
  // matrices).
  // TODO: use better analysis to avoid zeroing out the buffer?
  if (isInPlace(tensor)) {
    Value init =
        builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
    if (!isInit) {
      Value zero = constantZero(builder, loc, denseTp.getElementType());
      builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{init});
    }
    return init;
  }
  // By default, a new buffer is allocated which is either set to zero (when
  // no updates occur or the tensor materializes into this computation) or
  // initialized to the value of the tensor defined in the outs() clause.
  // This is always correct (since it enforces all assumptions above) but
  // may negatively impact running complexity as explained above.
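  // Rough sketch of the two paths generated below (illustrative only):
  //   buffer = alloc(dims)
  //   if (no update semantics or materializing output)
  //     fill(buffer, 0)              // enforce the all-zero assumption
  //   else
  //     copy(outs() buffer, buffer)  // preserve original output contents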
502 Value alloc = builder.create<memref::AllocOp>(loc, denseTp, args); 503 if (!isInit || isMaterializing(tensor)) { 504 Value zero = constantZero(builder, loc, denseTp.getElementType()); 505 builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{alloc}); 506 } else { 507 Value init = 508 builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor); 509 builder.create<memref::CopyOp>(loc, init, alloc); 510 } 511 return alloc; 512 } 513 514 /// Local bufferization of all dense and sparse data structures. 515 /// This code enables testing the first prototype sparse compiler. 516 // TODO: replace this with a proliferated bufferization strategy 517 static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder, 518 linalg::GenericOp op) { 519 Location loc = op.getLoc(); 520 assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); 521 // For every tensor, find lower and upper bound on dimensions, set the 522 // same bounds on loop indices, and obtain dense or sparse buffer(s). 523 SmallVector<Value, 4> args; 524 for (OpOperand *t : op.getInputAndOutputOperands()) { 525 unsigned tensor = t->getOperandNumber(); 526 auto shape = op.getShape(t); 527 auto map = op.getTiedIndexingMap(t); 528 auto enc = getSparseTensorEncoding(t->get().getType()); 529 // Scan all dimensions of current tensor. 530 args.clear(); 531 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 532 AffineExpr a = map.getResult(perm(enc, d)); 533 if (a.getKind() != AffineExprKind::DimId) 534 continue; // compound 535 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 536 // Handle sparse storage schemes. 537 if (merger.isDim(tensor, idx, Dim::kSparse)) { 538 auto dynShape = {ShapedType::kDynamicSize}; 539 auto ptrTp = 540 MemRefType::get(dynShape, getPointerOverheadType(builder, enc)); 541 auto indTp = 542 MemRefType::get(dynShape, getIndexOverheadType(builder, enc)); 543 Value dim = constantIndex(builder, loc, d); 544 // Generate sparse primitives to obtains pointer and indices. 545 codegen.pointers[tensor][idx] = 546 builder.create<ToPointersOp>(loc, ptrTp, t->get(), dim); 547 codegen.indices[tensor][idx] = 548 builder.create<ToIndicesOp>(loc, indTp, t->get(), dim); 549 } 550 // Find upper bound in current dimension. 551 unsigned p = perm(enc, d); 552 Value up = linalg::createOrFoldDimOp(builder, loc, t->get(), p); 553 if (ShapedType::isDynamic(shape[p])) 554 args.push_back(up); 555 assert(codegen.highs[tensor][idx] == nullptr); 556 codegen.sizes[idx] = codegen.highs[tensor][idx] = up; 557 } 558 // Perform the required bufferization. Dense inputs materialize 559 // from the input tensors. Dense outputs need special handling. 560 // Sparse inputs use sparse primitives to obtain the values. 561 // We also accept in-place all-dense annotated "sparse" outputs. 562 Type elementType = getElementTypeOrSelf(t->get().getType()); 563 if (!enc) { 564 // Non-annotated dense tensors. 565 auto denseTp = MemRefType::get(shape, elementType); 566 if (tensor < op.getNumInputs()) 567 codegen.buffers[tensor] = 568 builder.create<bufferization::ToMemrefOp>(loc, denseTp, t->get()); 569 else 570 codegen.buffers[tensor] = 571 genOutputBuffer(codegen, builder, op, denseTp, args); 572 } else if (t == codegen.sparseOut) { 573 // True sparse output needs a lexIdx array. 
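      // Rough shape of the scratch buffers created below (sketch only):
      //   lexIdx : memref<?xindex> sized by the output rank, holding the
      //            current lexicographic coordinates, and
      //   lexVal : memref<elementType>, holding the value to be inserted;
      // both are later handed to LexInsertOp (and lexIdx to CompressOp).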
574 Value rank = constantIndex(builder, loc, op.getRank(t)); 575 auto dynShape = {ShapedType::kDynamicSize}; 576 auto memTp = MemRefType::get(dynShape, builder.getIndexType()); 577 codegen.lexIdx = builder.create<memref::AllocaOp>(loc, memTp, rank); 578 codegen.lexVal = builder.create<memref::AllocaOp>( 579 loc, MemRefType::get({}, elementType)); 580 } else { 581 // Annotated sparse tensors. 582 auto dynShape = {ShapedType::kDynamicSize}; 583 auto sparseTp = MemRefType::get(dynShape, elementType); 584 codegen.buffers[tensor] = 585 builder.create<ToValuesOp>(loc, sparseTp, t->get()); 586 } 587 } 588 } 589 590 /// Constructs vector type. 591 static VectorType vectorType(CodeGen &codegen, Type etp) { 592 unsigned numScalableDims = codegen.options.enableVLAVectorization; 593 return VectorType::get(codegen.curVecLength, etp, numScalableDims); 594 } 595 596 /// Constructs vector type from pointer. 597 static VectorType vectorType(CodeGen &codegen, Value ptr) { 598 return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType()); 599 } 600 601 /// Constructs vector iteration mask. 602 static Value genVectorMask(CodeGen &codegen, OpBuilder &builder, Value iv, 603 Value lo, Value hi, Value step) { 604 Location loc = iv.getLoc(); 605 VectorType mtp = vectorType(codegen, builder.getI1Type()); 606 // Special case if the vector length evenly divides the trip count (for 607 // example, "for i = 0, 128, 16"). A constant all-true mask is generated 608 // so that all subsequent masked memory operations are immediately folded 609 // into unconditional memory operations. 610 IntegerAttr loInt, hiInt, stepInt; 611 if (matchPattern(lo, m_Constant(&loInt)) && 612 matchPattern(hi, m_Constant(&hiInt)) && 613 matchPattern(step, m_Constant(&stepInt))) { 614 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) 615 return builder.create<vector::BroadcastOp>( 616 loc, mtp, constantI1(builder, loc, true)); 617 } 618 // Otherwise, generate a vector mask that avoids overrunning the upperbound 619 // during vector execution. Here we rely on subsequent loop optimizations to 620 // avoid executing the mask in all iterations, for example, by splitting the 621 // loop into an unconditional vector loop and a scalar cleanup loop. 622 auto minMap = AffineMap::get( 623 /*dimCount=*/2, /*symbolCount=*/1, 624 {builder.getAffineSymbolExpr(0), 625 builder.getAffineDimExpr(0) - builder.getAffineDimExpr(1)}, 626 builder.getContext()); 627 Value end = 628 builder.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step}); 629 return builder.create<vector::CreateMaskOp>(loc, mtp, end); 630 } 631 632 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. 633 static Value genVectorLoad(CodeGen &codegen, OpBuilder &builder, Value ptr, 634 ArrayRef<Value> args) { 635 Location loc = ptr.getLoc(); 636 VectorType vtp = vectorType(codegen, ptr); 637 Value pass = constantZero(builder, loc, vtp); 638 if (args.back().getType().isa<VectorType>()) { 639 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 640 Value indexVec = args.back(); 641 scalarArgs.back() = constantIndex(builder, loc, 0); 642 return builder.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs, indexVec, 643 codegen.curVecMask, pass); 644 } 645 return builder.create<vector::MaskedLoadOp>(loc, vtp, ptr, args, 646 codegen.curVecMask, pass); 647 } 648 649 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. 
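/// For example (illustrative), with vector length 4 the sparse form
///   a[ind[lo:lo+4]] = rhs
/// becomes a masked vector::ScatterOp through the index vector, whereas the
/// dense form a[lo:lo+4] = rhs becomes a masked vector::MaskedStoreOp.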
650 static void genVectorStore(CodeGen &codegen, OpBuilder &builder, Value rhs, 651 Value ptr, ArrayRef<Value> args) { 652 Location loc = ptr.getLoc(); 653 if (args.back().getType().isa<VectorType>()) { 654 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 655 Value indexVec = args.back(); 656 scalarArgs.back() = constantIndex(builder, loc, 0); 657 builder.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, 658 codegen.curVecMask, rhs); 659 return; 660 } 661 builder.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask, 662 rhs); 663 } 664 665 /// Generates a vectorized invariant. Here we rely on subsequent loop 666 /// optimizations to hoist the invariant broadcast out of the vector loop. 667 static Value genVectorInvariantValue(CodeGen &codegen, OpBuilder &builder, 668 Value val) { 669 VectorType vtp = vectorType(codegen, val.getType()); 670 return builder.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 671 } 672 673 /// Generates an affine expression. 674 // 675 // TODO: generalize for sparse tensor subscripts 676 // 677 static Value genAffine(CodeGen &codegen, OpBuilder &builder, AffineExpr a, 678 Location loc) { 679 switch (a.getKind()) { 680 case AffineExprKind::DimId: { 681 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 682 return codegen.loops[idx]; // universal dense index 683 } 684 case AffineExprKind::Add: { 685 auto binOp = a.cast<AffineBinaryOpExpr>(); 686 return builder.create<arith::AddIOp>( 687 loc, genAffine(codegen, builder, binOp.getLHS(), loc), 688 genAffine(codegen, builder, binOp.getRHS(), loc)); 689 } 690 case AffineExprKind::Mul: { 691 auto binOp = a.cast<AffineBinaryOpExpr>(); 692 return builder.create<arith::MulIOp>( 693 loc, genAffine(codegen, builder, binOp.getLHS(), loc), 694 genAffine(codegen, builder, binOp.getRHS(), loc)); 695 } 696 case AffineExprKind::Constant: { 697 int64_t c = a.cast<AffineConstantExpr>().getValue(); 698 return constantIndex(builder, loc, c); 699 } 700 default: 701 llvm_unreachable("unexpected affine subscript"); 702 } 703 } 704 705 /// Generates index for load/store on sparse tensor. 706 static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) { 707 auto map = op.getTiedIndexingMap(t); 708 auto enc = getSparseTensorEncoding(t->get().getType()); 709 AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1)); 710 assert(a.getKind() == AffineExprKind::DimId); 711 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 712 return codegen.loops[idx]; 713 } 714 715 /// Generates subscript for load/store on a dense or sparse tensor. 716 static Value genSubscript(CodeGen &codegen, OpBuilder &builder, 717 linalg::GenericOp op, OpOperand *t, 718 SmallVector<Value, 4> &args) { 719 unsigned tensor = t->getOperandNumber(); 720 auto map = op.getTiedIndexingMap(t); 721 auto enc = getSparseTensorEncoding(t->get().getType()); 722 unsigned rank = map.getNumResults(); 723 if (enc) { 724 // Note that currently, all sparse subscripts are simple. 725 // TODO: accept affine too? 
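    // Example (illustrative): for a 2-d CSR operand A in x(i) += A(i,j) * b(j),
    // the value is read as A.values[pidx] through the innermost position
    // index, so only that single position index is pushed as the subscript.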
726 AffineExpr a = map.getResult(perm(enc, rank - 1)); 727 assert(a.getKind() == AffineExprKind::DimId); 728 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 729 assert(codegen.pidxs[tensor][idx] != nullptr); 730 args.push_back(codegen.pidxs[tensor][idx]); // position index 731 } else { 732 for (unsigned d = 0; d < rank; d++) { 733 AffineExpr a = map.getResult(perm(enc, d)); 734 args.push_back(genAffine(codegen, builder, a, op.getLoc())); 735 } 736 } 737 return codegen.buffers[tensor]; 738 } 739 740 /// Generates insertion code to implement dynamic tensor load. 741 static Value genInsertionLoad(CodeGen &codegen, OpBuilder &builder, 742 linalg::GenericOp op, OpOperand *t) { 743 Location loc = op.getLoc(); 744 // Direct lexicographic index order, tensor loads as zero. 745 if (!codegen.expValues) { 746 Type tp = getElementTypeOrSelf(t->get().getType()); 747 return constantZero(builder, loc, tp); 748 } 749 // Load from expanded access pattern. 750 Value index = genIndex(codegen, op, t); 751 return builder.create<memref::LoadOp>(loc, codegen.expValues, index); 752 } 753 754 /// Generates insertion code to implement dynamic tensor store. 755 static void genInsertionStore(CodeGen &codegen, OpBuilder &builder, 756 linalg::GenericOp op, OpOperand *t, Value rhs) { 757 Location loc = op.getLoc(); 758 // Direct insertion in lexicographic index order. 759 if (!codegen.expValues) { 760 builder.create<memref::StoreOp>(loc, rhs, codegen.lexVal); 761 builder.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, codegen.lexVal); 762 return; 763 } 764 // Generates insertion code along expanded access pattern. 765 // if (!expFilled[i]) then 766 // expFilled[i] = true 767 // expAdded[inserts++] = i 768 // endif 769 // values[i] = rhs 770 Value index = genIndex(codegen, op, t); 771 Value fval = constantI1(builder, loc, false); 772 Value tval = constantI1(builder, loc, true); 773 // If statement. 774 Value filled = builder.create<memref::LoadOp>(loc, codegen.expFilled, index); 775 Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 776 filled, fval); 777 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond, 778 /*else=*/true); 779 // True branch. 780 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 781 builder.create<memref::StoreOp>(loc, tval, codegen.expFilled, index); 782 builder.create<memref::StoreOp>(loc, index, codegen.expAdded, 783 codegen.expCount); 784 Value one = constantIndex(builder, loc, 1); 785 Value add = builder.create<arith::AddIOp>(loc, codegen.expCount, one); 786 builder.create<scf::YieldOp>(loc, add); 787 // False branch. 788 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 789 builder.create<scf::YieldOp>(loc, codegen.expCount); 790 builder.setInsertionPointAfter(ifOp); 791 // Value assignment. 792 codegen.expCount = ifOp.getResult(0); 793 builder.create<memref::StoreOp>(loc, rhs, codegen.expValues, index); 794 } 795 796 /// Generates a load on a dense or sparse tensor. 797 static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder, 798 linalg::GenericOp op, unsigned exp) { 799 // Test if the load was hoisted to a higher loop nest. 800 Value val = merger.exp(exp).val; 801 if (val) { 802 if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>()) 803 return genVectorInvariantValue(codegen, builder, val); 804 return val; 805 } 806 // Load during insertion. 
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  if (t == codegen.sparseOut)
    return genInsertionLoad(codegen, builder, op, t);
  // Actual load.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, builder, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, builder, ptr, args);
  return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned exp, Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = builder.create<arith::SelectOp>(loc, codegen.curVecMask, rhs,
                                            codegen.redVal);
    updateReduc(merger, codegen, rhs);
    return;
  }
  // Store during insertion.
  OpOperand *t = op.getOutputOperand(0);
  if (t == codegen.sparseOut) {
    if (!rhs) {
      // Only unary and binary are allowed to return uninitialized rhs
      // to indicate missing output.
      assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary);
    } else {
      genInsertionStore(codegen, builder, op, t, rhs);
    }
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, builder, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, builder, rhs, ptr, args);
  else
    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, OpBuilder &builder, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, builder, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = builder.create<arith::ExtUIOp>(
            loc, vectorType(codegen, builder.getI32Type()), vload);
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = builder.create<arith::ExtUIOp>(
            loc, vectorType(codegen, builder.getI64Type()), vload);
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty.
Here too, 885 // however, indices that already are 64-bit, in theory, cannot express the 886 // full range as explained above. 887 Value load = builder.create<memref::LoadOp>(loc, ptr, s); 888 if (!load.getType().isa<IndexType>()) { 889 if (load.getType().getIntOrFloatBitWidth() < 64) 890 load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load); 891 load = 892 builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load); 893 } 894 return load; 895 } 896 897 /// Generates an invariant value. 898 static Value genInvariantValue(Merger &merger, CodeGen &codegen, 899 OpBuilder &builder, unsigned exp) { 900 Value val = merger.exp(exp).val; 901 if (codegen.curVecLength > 1) 902 return genVectorInvariantValue(codegen, builder, val); 903 return val; 904 } 905 906 /// Generates an address computation "sz * p + i". 907 static Value genAddress(CodeGen &codegen, OpBuilder &builder, Location loc, 908 Value size, Value p, Value i) { 909 Value mul = builder.create<arith::MulIOp>(loc, size, p); 910 if (auto vtp = i.getType().dyn_cast<VectorType>()) { 911 Value inv = 912 builder.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul); 913 mul = genVectorInvariantValue(codegen, builder, inv); 914 } 915 return builder.create<arith::AddIOp>(loc, mul, i); 916 } 917 918 /// Generates an index value. 919 static Value genIndexValue(CodeGen &codegen, OpBuilder &builder, unsigned idx, 920 unsigned ldx) { 921 Value ival = codegen.loops[idx]; 922 Type itype = ival.getType(); 923 // During vectorization, we either encounter: 924 // (1) indices already in vector form, as in ... = ind[lo:hi], good to go, or 925 // (2) single index, as in ... = i, must convert to [i, i+1, ...] for inner i. 926 unsigned vl = codegen.curVecLength; 927 if (vl > 1 && !itype.isa<VectorType>()) { 928 Location loc = ival.getLoc(); 929 VectorType vtp = vectorType(codegen, itype); 930 ival = builder.create<vector::BroadcastOp>(loc, vtp, ival); 931 if (idx == ldx) { 932 Value incr; 933 if (vtp.isScalable()) { 934 Type stepvty = vectorType(codegen, builder.getI64Type()); 935 Value stepv = builder.create<LLVM::StepVectorOp>(loc, stepvty); 936 incr = builder.create<arith::IndexCastOp>(loc, vtp, stepv); 937 } else { 938 SmallVector<APInt, 4> integers; 939 for (unsigned i = 0; i < vl; i++) 940 integers.push_back(APInt(/*width=*/64, i)); 941 auto values = DenseElementsAttr::get(vtp, integers); 942 incr = builder.create<arith::ConstantOp>(loc, vtp, values); 943 } 944 ival = builder.create<arith::AddIOp>(loc, ival, incr); 945 } 946 } 947 return ival; 948 } 949 950 /// Semi-ring branches are simply inlined by the sparse compiler. Prior 951 /// analysis has verified that all computations are "local" to the inlined 952 /// branch or otherwise invariantly defined outside the loop nest, with the 953 /// exception of index computations, which need to be relinked to actual 954 /// inlined cloned code. 955 static Value relinkBranch(CodeGen &codegen, RewriterBase &rewriter, 956 Block *block, Value e, unsigned ldx) { 957 if (Operation *def = e.getDefiningOp()) { 958 if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) 959 return genIndexValue(codegen, rewriter, indexOp.dim(), ldx); 960 if (def->getBlock() == block) { 961 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) 962 def->setOperand( 963 i, relinkBranch(codegen, rewriter, block, def->getOperand(i), ldx)); 964 } 965 } 966 return e; 967 } 968 969 /// Recursively generates tensor expression. 
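/// For example (illustrative), for a(i) * b(i) + c(i) the expression tree
///        (+)
///       /   \
///     (*)    c
///    /   \
///   a     b
/// is visited bottom-up: loads (or hoisted invariants) are generated for the
/// leaf tensors and the arithmetic is emitted for the interior nodes.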
970 static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, 971 linalg::GenericOp op, unsigned exp, unsigned ldx) { 972 Location loc = op.getLoc(); 973 if (exp == -1u) 974 return Value(); 975 if (merger.exp(exp).kind == Kind::kTensor) 976 return genTensorLoad(merger, codegen, rewriter, op, exp); 977 if (merger.exp(exp).kind == Kind::kInvariant) 978 return genInvariantValue(merger, codegen, rewriter, exp); 979 if (merger.exp(exp).kind == Kind::kIndex) 980 return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx); 981 Value v0 = 982 genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx); 983 Value v1 = 984 genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx); 985 Value ee = merger.buildExp(rewriter, loc, exp, v0, v1); 986 if (ee && (merger.exp(exp).kind == Kind::kUnary || 987 merger.exp(exp).kind == Kind::kBinary || 988 merger.exp(exp).kind == Kind::kBinaryBranch)) 989 ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx); 990 return ee; 991 } 992 993 /// Determines if affine expression is invariant. 994 static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a, 995 unsigned ldx, bool &atLevel) { 996 switch (a.getKind()) { 997 case AffineExprKind::DimId: { 998 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 999 if (idx == ldx) 1000 atLevel = true; 1001 return codegen.loops[idx] != nullptr; // no longer in play? 1002 } 1003 case AffineExprKind::Add: 1004 case AffineExprKind::Mul: { 1005 auto binOp = a.cast<AffineBinaryOpExpr>(); 1006 return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) && 1007 isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel); 1008 } 1009 default: 1010 return true; 1011 } 1012 } 1013 1014 /// Hoists loop invariant tensor loads for which indices have been exhausted. 1015 static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1016 linalg::GenericOp op, unsigned exp, unsigned ldx, 1017 bool atStart, Kind last = Kind::kTensor) { 1018 if (exp == -1u) 1019 return; 1020 if (merger.exp(exp).kind == Kind::kTensor) { 1021 // Inspect tensor indices. 1022 bool atLevel = ldx == -1u; 1023 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 1024 auto map = op.getTiedIndexingMap(t); 1025 auto enc = getSparseTensorEncoding(t->get().getType()); 1026 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1027 AffineExpr a = map.getResult(perm(enc, d)); 1028 if (!isInvariantAffine(codegen, a, ldx, atLevel)) 1029 return; // still in play 1030 } 1031 // All exhausted at this level (atLevel denotes exactly at this level). 1032 if (!atLevel) 1033 return; 1034 OpOperand *lhs = op.getOutputOperand(0); 1035 if (lhs == t) { 1036 // Start or end a scalarized reduction 1037 if (atStart) { 1038 Value load = genTensorLoad(merger, codegen, builder, op, exp); 1039 codegen.redKind = getReduction(last); 1040 codegen.redExp = exp; 1041 updateReduc(merger, codegen, load); 1042 } else { 1043 Value redVal = codegen.redVal; 1044 updateReduc(merger, codegen, Value()); 1045 codegen.redExp = -1u; 1046 codegen.redKind = kNoReduc; 1047 genTensorStore(merger, codegen, builder, op, exp, redVal); 1048 } 1049 } else { 1050 // Start or end loop invariant hoisting of a tensor load. 1051 merger.exp(exp).val = 1052 atStart ? genTensorLoad(merger, codegen, builder, op, exp) : Value(); 1053 } 1054 } else if (merger.exp(exp).kind != Kind::kInvariant && 1055 merger.exp(exp).kind != Kind::kIndex) { 1056 // Traverse into the binary operations. 
Note that we only hoist 1057 // tensor loads, since subsequent MLIR/LLVM passes know how to 1058 // deal with all other kinds of derived loop invariants. 1059 Kind last = merger.exp(exp).kind; 1060 unsigned e0 = merger.exp(exp).children.e0; 1061 unsigned e1 = merger.exp(exp).children.e1; 1062 genInvariants(merger, codegen, builder, op, e0, ldx, atStart, last); 1063 genInvariants(merger, codegen, builder, op, e1, ldx, atStart, last); 1064 } 1065 } 1066 1067 /// Generates an expanded access pattern in innermost dimension. 1068 static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1069 linalg::GenericOp op, unsigned at, bool atStart) { 1070 OpOperand *lhs = codegen.sparseOut; 1071 if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 || 1072 at != codegen.outerParNest) 1073 return; // not needed at this level 1074 // Generate start or end of an expanded access pattern. 1075 Value tensor = lhs->get(); 1076 Location loc = op.getLoc(); 1077 if (atStart) { 1078 auto dynShape = {ShapedType::kDynamicSize}; 1079 Type etp = tensor.getType().cast<ShapedType>().getElementType(); 1080 Type t1 = MemRefType::get(dynShape, etp); 1081 Type t2 = MemRefType::get(dynShape, builder.getI1Type()); 1082 Type t3 = MemRefType::get(dynShape, builder.getIndexType()); 1083 Type t4 = builder.getIndexType(); 1084 auto res = 1085 builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor); 1086 assert(res.getNumResults() == 4); 1087 assert(!codegen.expValues); 1088 codegen.expValues = res.getResult(0); 1089 codegen.expFilled = res.getResult(1); 1090 codegen.expAdded = res.getResult(2); 1091 codegen.expCount = res.getResult(3); 1092 } else { 1093 assert(codegen.expValues); 1094 builder.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues, 1095 codegen.expFilled, codegen.expAdded, 1096 codegen.expCount); 1097 codegen.expValues = codegen.expFilled = codegen.expAdded = 1098 codegen.expCount = Value(); 1099 } 1100 } 1101 1102 /// Generates initialization code for the subsequent loop sequence at 1103 /// current index level. Returns true if the loop sequence needs to 1104 /// maintain the universal index. 1105 static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1106 linalg::GenericOp op, std::vector<unsigned> &topSort, 1107 unsigned at, BitVector &inits) { 1108 bool needsUniv = false; 1109 Location loc = op.getLoc(); 1110 unsigned idx = topSort[at]; 1111 1112 // Initialize sparse positions. 1113 for (unsigned b = 0, be = inits.size(); b < be; b++) { 1114 if (inits[b]) { 1115 unsigned tensor = merger.tensor(b); 1116 assert(idx == merger.index(b)); 1117 if (merger.isDim(b, Dim::kSparse)) { 1118 // Initialize sparse index. 1119 unsigned pat = at; 1120 for (; pat != 0; pat--) { 1121 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1122 break; 1123 } 1124 Value ptr = codegen.pointers[tensor][idx]; 1125 Value one = constantIndex(builder, loc, 1); 1126 Value p0 = (pat == 0) ? constantIndex(builder, loc, 0) 1127 : codegen.pidxs[tensor][topSort[pat - 1]]; 1128 codegen.pidxs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p0); 1129 Value p1 = builder.create<arith::AddIOp>(loc, p0, one); 1130 codegen.highs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p1); 1131 } else { 1132 // Dense index still in play. 1133 needsUniv = true; 1134 } 1135 } 1136 } 1137 1138 // Initialize the universal dense index. 1139 codegen.loops[idx] = constantIndex(builder, loc, 0); 1140 return needsUniv; 1141 } 1142 1143 /// Returns vectorization strategy. 
Any implicit inner loop in the Linalg 1144 /// operation is a candidate. Whether it is actually converted to SIMD code 1145 /// depends on the requested strategy. 1146 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isReduction, 1147 bool isSparse) { 1148 // Reject vectorization of sparse output, unless innermost is reduction. 1149 if (codegen.sparseOut && !isReduction) 1150 return false; 1151 // Inspect strategy. 1152 switch (codegen.options.vectorizationStrategy) { 1153 case SparseVectorizationStrategy::kNone: 1154 return false; 1155 case SparseVectorizationStrategy::kDenseInnerLoop: 1156 return isInner && !isSparse; 1157 case SparseVectorizationStrategy::kAnyStorageInnerLoop: 1158 return isInner; 1159 } 1160 llvm_unreachable("unexpected vectorization strategy"); 1161 } 1162 1163 /// Returns parallelization strategy. Any implicit loop in the Linalg operation 1164 /// that is marked "parallel" is a candidate. Whether it is actually converted 1165 /// to a parallel operation depends on the requested strategy. 1166 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, 1167 bool isSparse, bool isVector) { 1168 // Reject parallelization of sparse output. 1169 if (codegen.sparseOut) 1170 return false; 1171 // Inspect strategy. 1172 switch (codegen.options.parallelizationStrategy) { 1173 case SparseParallelizationStrategy::kNone: 1174 return false; 1175 case SparseParallelizationStrategy::kDenseOuterLoop: 1176 return isOuter && !isSparse && !isReduction && !isVector; 1177 case SparseParallelizationStrategy::kAnyStorageOuterLoop: 1178 return isOuter && !isReduction && !isVector; 1179 case SparseParallelizationStrategy::kDenseAnyLoop: 1180 return !isSparse && !isReduction && !isVector; 1181 case SparseParallelizationStrategy::kAnyStorageAnyLoop: 1182 return !isReduction && !isVector; 1183 } 1184 llvm_unreachable("unexpected parallelization strategy"); 1185 } 1186 1187 /// Checks unit stride for dense tensors. The iteration graph may have ignored 1188 /// dense access patterns in order to avoid cycles (sparse access patterns are 1189 /// always placed innermost), but that means dense access has become strided. 1190 /// This prevents effective vectorization. 1191 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 1192 unsigned idx) { 1193 for (OpOperand *t : op.getInputAndOutputOperands()) { 1194 if (!getSparseTensorEncoding(t->get().getType())) { 1195 auto map = op.getTiedIndexingMap(t); 1196 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1197 AffineExpr a = map.getResult(d); 1198 // Report non-unit stride if innermost index appears at an outer 1199 // dimension (true non-unit stride) or if the innermost index appears 1200 // in a compound subscript in the innermost dimension. Even if the 1201 // latter is unit stride, it does not play well with scatter/gather. 1202 // TODO: accept unit stride affine innermost like a[i,j+k+1]? 1203 if (a.isFunctionOfDim(idx) && 1204 ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) 1205 return false; 1206 } 1207 } 1208 } 1209 return true; 1210 } 1211 1212 /// Generates a for-loop on a single index. 
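/// Depending on the strategy flags, this is roughly (illustrative sketch):
///   scf.for %i = %lo to %hi step %step iter_args(...)   sequential/vector
///   scf.parallel (%i) = (%lo) to (%hi) step (%step)     parallel
/// where a sparse index iterates over the position range [pidx, high) and a
/// dense index iterates over the dimension size, with the vector case using
/// a step equal to the vector length and a mask for the loop tail.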
1213 static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1214 linalg::GenericOp op, bool isOuter, bool isInner, 1215 unsigned idx, BitVector &indices) { 1216 unsigned fb = indices.find_first(); 1217 unsigned tensor = merger.tensor(fb); 1218 assert(idx == merger.index(fb)); 1219 auto iteratorTypes = op.iterator_types().getValue(); 1220 bool isReduction = isReductionIterator(iteratorTypes[idx]); 1221 bool isSparse = merger.isDim(fb, Dim::kSparse); 1222 bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) && 1223 denseUnitStrides(merger, op, idx); 1224 bool isParallel = 1225 isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 1226 1227 // Prepare vector length. 1228 if (isVector) 1229 codegen.curVecLength = codegen.options.vectorLength; 1230 1231 // Loop bounds and increment. 1232 Location loc = op.getLoc(); 1233 Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 1234 Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 1235 Value step = constantIndex(builder, loc, codegen.curVecLength); 1236 if (isVector && codegen.options.enableVLAVectorization) { 1237 Value vscale = builder.create<vector::VectorScaleOp>( 1238 loc, IndexType::get(builder.getContext())); 1239 step = builder.create<arith::MulIOp>(loc, vscale, step); 1240 } 1241 1242 // Emit a parallel loop. 1243 if (isParallel) { 1244 assert(!isVector); 1245 scf::ParallelOp parOp = builder.create<scf::ParallelOp>(loc, lo, hi, step); 1246 if (isSparse) 1247 codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 1248 else 1249 codegen.loops[idx] = parOp.getInductionVars()[0]; 1250 builder.setInsertionPointToStart(parOp.getBody()); 1251 return parOp; 1252 } 1253 1254 // Emit a sequential or vector loop. 1255 SmallVector<Value, 4> operands; 1256 if (codegen.redVal) { 1257 // In a vector loop, bring reduction into SIMD form, if not already. 1258 if (isVector && !codegen.redVal.getType().isa<VectorType>()) { 1259 VectorType vtp = vectorType(codegen, codegen.redVal.getType()); 1260 Value vred = genVectorReducInit(codegen, builder, loc, vtp); 1261 updateReduc(merger, codegen, vred); 1262 } 1263 operands.push_back(codegen.redVal); 1264 } 1265 if (codegen.expValues) 1266 operands.push_back(codegen.expCount); 1267 scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, operands); 1268 if (codegen.redVal) 1269 updateReduc(merger, codegen, forOp.getRegionIterArgs().front()); 1270 if (codegen.expValues) 1271 codegen.expCount = forOp.getRegionIterArgs().back(); 1272 // Assign induction variable to sparse or dense index. 1273 Value iv = forOp.getInductionVar(); 1274 if (isSparse) 1275 codegen.pidxs[tensor][idx] = iv; 1276 else 1277 codegen.loops[idx] = iv; 1278 builder.setInsertionPointToStart(forOp.getBody()); 1279 // Share vector iteration mask between all subsequent loads/stores. 1280 if (isVector) 1281 codegen.curVecMask = genVectorMask(codegen, builder, iv, lo, hi, step); 1282 return forOp; 1283 } 1284 1285 /// Emit a while-loop for co-iteration over multiple indices. 1286 static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1287 linalg::GenericOp op, unsigned idx, bool needsUniv, 1288 BitVector &indices) { 1289 SmallVector<Type, 4> types; 1290 SmallVector<Value, 4> operands; 1291 // Construct the while-loop with a parameter for each index. 
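  // For example (simplified sketch), co-iterating two sparse vectors in
  // x(i) = a(i) + b(i) yields a loop of roughly this shape:
  //   while (pa < pa_end && pb < pb_end) {
  //     ia = a.indices[pa]; ib = b.indices[pb]; i = min(ia, ib);
  //     if (ia == i && ib == i)  x(i) = a.values[pa] + b.values[pb];
  //     else if (ia == i)        x(i) = a.values[pa];
  //     else                     x(i) = b.values[pb];
  //     pa += (ia == i); pb += (ib == i);
  //   }
  // with this method emitting only the while-loop itself and the "before"
  // region tests; the body and the induction are generated elsewhere.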
1292 Type indexType = builder.getIndexType(); 1293 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1294 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1295 unsigned tensor = merger.tensor(b); 1296 assert(idx == merger.index(b)); 1297 types.push_back(indexType); 1298 operands.push_back(codegen.pidxs[tensor][idx]); 1299 } 1300 } 1301 if (codegen.redVal) { 1302 types.push_back(codegen.redVal.getType()); 1303 operands.push_back(codegen.redVal); 1304 } 1305 if (codegen.expValues) { 1306 types.push_back(indexType); 1307 operands.push_back(codegen.expCount); 1308 } 1309 if (needsUniv) { 1310 types.push_back(indexType); 1311 operands.push_back(codegen.loops[idx]); 1312 } 1313 assert(types.size() == operands.size()); 1314 Location loc = op.getLoc(); 1315 scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands); 1316 1317 SmallVector<Location> locs(types.size(), loc); 1318 Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs); 1319 Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs); 1320 1321 // Build the "before" region, which effectively consists 1322 // of a conjunction of "i < upper" tests on all induction. 1323 builder.setInsertionPointToStart(&whileOp.getBefore().front()); 1324 Value cond; 1325 unsigned o = 0; 1326 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1327 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1328 unsigned tensor = merger.tensor(b); 1329 assert(idx == merger.index(b)); 1330 Value op1 = before->getArgument(o); 1331 Value op2 = codegen.highs[tensor][idx]; 1332 Value opc = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 1333 op1, op2); 1334 cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc; 1335 codegen.pidxs[tensor][idx] = after->getArgument(o++); 1336 } 1337 } 1338 if (codegen.redVal) 1339 updateReduc(merger, codegen, after->getArgument(o++)); 1340 if (codegen.expValues) 1341 codegen.expCount = after->getArgument(o++); 1342 if (needsUniv) 1343 codegen.loops[idx] = after->getArgument(o++); 1344 assert(o == operands.size()); 1345 builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); 1346 builder.setInsertionPointToStart(&whileOp.getAfter().front()); 1347 return whileOp; 1348 } 1349 1350 /// Generates a for-loop or a while-loop, depending on whether it implements 1351 /// singleton iteration or co-iteration over the given conjunction. 1352 static Operation *genLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1353 linalg::GenericOp op, std::vector<unsigned> &topSort, 1354 unsigned at, bool needsUniv, BitVector &indices) { 1355 unsigned idx = topSort[at]; 1356 if (indices.count() == 1) { 1357 bool isOuter = at == 0; 1358 bool isInner = at == topSort.size() - 1; 1359 return genFor(merger, codegen, builder, op, isOuter, isInner, idx, indices); 1360 } 1361 return genWhile(merger, codegen, builder, op, idx, needsUniv, indices); 1362 } 1363 1364 /// Generates the local variables for this loop, consisting of the sparse 1365 /// indices, restored universal dense index, and dense positions. 1366 static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1367 linalg::GenericOp op, std::vector<unsigned> &topSort, 1368 unsigned at, bool needsUniv, BitVector &locals) { 1369 Location loc = op.getLoc(); 1370 unsigned idx = topSort[at]; 1371 1372 // Initialize sparse indices. 
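  // Sketch of what is generated below for each sparse tensor t at this level:
  //   idx_t = t.indices[pidx_t]
  // and, when no universal index is in play, the merged loop index becomes
  //   i = min(idx_t0, idx_t1, ...)
  // built with a chain of arith::CmpIOp / arith::SelectOp operations.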
1373 Value min; 1374 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1375 if (locals[b] && merger.isDim(b, Dim::kSparse)) { 1376 unsigned tensor = merger.tensor(b); 1377 assert(idx == merger.index(b)); 1378 Value ptr = codegen.indices[tensor][idx]; 1379 Value s = codegen.pidxs[tensor][idx]; 1380 Value load = genLoad(codegen, builder, loc, ptr, s); 1381 codegen.idxs[tensor][idx] = load; 1382 if (!needsUniv) { 1383 if (min) { 1384 Value cmp = builder.create<arith::CmpIOp>( 1385 loc, arith::CmpIPredicate::ult, load, min); 1386 min = builder.create<arith::SelectOp>(loc, cmp, load, min); 1387 } else { 1388 min = load; 1389 } 1390 } 1391 } 1392 } 1393 1394 // Merge dense universal index over minimum. 1395 if (min) { 1396 assert(!needsUniv); 1397 codegen.loops[idx] = min; 1398 } 1399 1400 // Initialize dense positions. Note that we generate dense indices of the 1401 // output tensor unconditionally, since they may not appear in the lattice, 1402 // but may be needed for linearized codegen. 1403 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1404 if ((locals[b] || merger.isOutTensor(b, idx)) && 1405 merger.isDim(b, Dim::kDense)) { 1406 unsigned tensor = merger.tensor(b); 1407 assert(idx == merger.index(b)); 1408 unsigned pat = at; 1409 for (; pat != 0; pat--) 1410 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1411 break; 1412 Value p = (pat == 0) ? constantIndex(builder, loc, 0) 1413 : codegen.pidxs[tensor][topSort[pat - 1]]; 1414 codegen.pidxs[tensor][idx] = genAddress( 1415 codegen, builder, loc, codegen.sizes[idx], p, codegen.loops[idx]); 1416 } 1417 } 1418 1419 // Move the insertion indices in lexicographic index order. During access 1420 // pattern expansion, we can skip setting the innermost dimension. 1421 if (codegen.sparseOut && !codegen.expValues) { 1422 Value pos = constantIndex(builder, loc, at); 1423 builder.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx, 1424 pos); 1425 } 1426 } 1427 1428 /// Generates the induction structure for a while-loop. 1429 static void genWhileInduction(Merger &merger, CodeGen &codegen, 1430 OpBuilder &builder, linalg::GenericOp op, 1431 unsigned idx, bool needsUniv, 1432 BitVector &induction, scf::WhileOp whileOp) { 1433 Location loc = op.getLoc(); 1434 // Finalize each else branch of all if statements. 1435 if (codegen.redVal || codegen.expValues) { 1436 while (auto ifOp = dyn_cast_or_null<scf::IfOp>( 1437 builder.getInsertionBlock()->getParentOp())) { 1438 unsigned y = 0; 1439 SmallVector<Value, 4> yields; 1440 if (codegen.redVal) { 1441 yields.push_back(codegen.redVal); 1442 updateReduc(merger, codegen, ifOp.getResult(y++)); 1443 } 1444 if (codegen.expValues) { 1445 yields.push_back(codegen.expCount); 1446 codegen.expCount = ifOp->getResult(y++); 1447 } 1448 assert(y == yields.size()); 1449 builder.create<scf::YieldOp>(loc, yields); 1450 builder.setInsertionPointAfter(ifOp); 1451 } 1452 } 1453 builder.setInsertionPointToEnd(&whileOp.getAfter().front()); 1454 // Finalize the induction. Note that the induction could be performed 1455 // in the individual if-branches to avoid re-evaluating the conditions. 1456 // However, that would result in a rather elaborate forest of yield 1457 // instructions during code generation. Moreover, performing the induction 1458 // after the if-statements more closely resembles code generated by TACO. 
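// In pseudo-code, each sparse position advances only if its index matched
// the merged loop index, and the universal index (if present) is simply
// incremented (illustrative sketch, scalar case):
//
//   pidx[t]  = (idx[t] == loops[i]) ? pidx[t] + 1 : pidx[t]
//   loops[i] = loops[i] + 1            // only when the universal index is in use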
1459 unsigned o = 0; 1460 SmallVector<Value, 4> operands; 1461 Value one = constantIndex(builder, loc, 1); 1462 for (unsigned b = 0, be = induction.size(); b < be; b++) { 1463 if (induction[b] && merger.isDim(b, Dim::kSparse)) { 1464 unsigned tensor = merger.tensor(b); 1465 assert(idx == merger.index(b)); 1466 Value op1 = codegen.idxs[tensor][idx]; 1467 Value op2 = codegen.loops[idx]; 1468 Value op3 = codegen.pidxs[tensor][idx]; 1469 Value cmp = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 1470 op1, op2); 1471 Value add = builder.create<arith::AddIOp>(loc, op3, one); 1472 operands.push_back(builder.create<arith::SelectOp>(loc, cmp, add, op3)); 1473 codegen.pidxs[tensor][idx] = whileOp->getResult(o++); 1474 } 1475 } 1476 if (codegen.redVal) { 1477 operands.push_back(codegen.redVal); 1478 updateReduc(merger, codegen, whileOp->getResult(o++)); 1479 } 1480 if (codegen.expValues) { 1481 operands.push_back(codegen.expCount); 1482 codegen.expCount = whileOp->getResult(o++); 1483 } 1484 if (needsUniv) { 1485 operands.push_back( 1486 builder.create<arith::AddIOp>(loc, codegen.loops[idx], one)); 1487 codegen.loops[idx] = whileOp->getResult(o++); 1488 } 1489 assert(o == operands.size()); 1490 builder.create<scf::YieldOp>(loc, operands); 1491 builder.setInsertionPointAfter(whileOp); 1492 } 1493 1494 /// Generates the induction structure for a for-loop. 1495 static void genForInduction(Merger &merger, CodeGen &codegen, 1496 OpBuilder &builder, linalg::GenericOp op, 1497 Operation *loop) { 1498 Location loc = op.getLoc(); 1499 unsigned o = 0; 1500 SmallVector<Value, 4> operands; 1501 if (codegen.redVal) { 1502 operands.push_back(codegen.redVal); 1503 updateReduc(merger, codegen, loop->getResult(o++)); 1504 } 1505 if (codegen.expValues) { 1506 operands.push_back(codegen.expCount); 1507 codegen.expCount = loop->getResult(o++); 1508 } 1509 assert(o == operands.size()); 1510 if (o > 0) 1511 builder.create<scf::YieldOp>(loc, operands); 1512 builder.setInsertionPointAfter(loop); 1513 } 1514 1515 /// Generates a single if-statement within a while-loop. 1516 static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1517 linalg::GenericOp op, unsigned idx, 1518 BitVector &conditions) { 1519 Location loc = op.getLoc(); 1520 SmallVector<Type, 4> types; 1521 Value cond; 1522 for (unsigned b = 0, be = conditions.size(); b < be; b++) { 1523 if (conditions[b]) { 1524 unsigned tensor = merger.tensor(b); 1525 assert(idx == merger.index(b)); 1526 Value clause; 1527 if (merger.isDim(b, Dim::kSparse)) { 1528 Value op1 = codegen.idxs[tensor][idx]; 1529 Value op2 = codegen.loops[idx]; 1530 clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 1531 op1, op2); 1532 } else { 1533 clause = constantI1(builder, loc, true); 1534 } 1535 cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause; 1536 } 1537 } 1538 if (codegen.redVal) 1539 types.push_back(codegen.redVal.getType()); 1540 if (codegen.expValues) 1541 types.push_back(builder.getIndexType()); 1542 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); 1543 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 1544 return ifOp; 1545 } 1546 1547 /// Generates end of true branch of if-statement within a while-loop. 
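/// The then-branch yields the reduction value and expansion count that are
/// live at this point (if any), after which the insertion point moves into
/// the else-branch so the next lattice point can be generated there. A
/// minimal sketch of the resulting structure (assuming, for illustration
/// only, a live f64 reduction and an expansion count):
///
///   %x:2 = scf.if %cond -> (f64, index) {
///     ...                              // body for this lattice point
///     scf.yield %red, %count : f64, index
///   } else {
///     scf.yield %red_in, %count_in : f64, index
///   }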
1548 static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
1549 linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
1550 Value redInput, Value cntInput) {
1551 SmallVector<Value, 4> operands;
1552 if (codegen.redVal) {
1553 operands.push_back(codegen.redVal);
1554 updateReduc(merger, codegen, redInput);
1555 }
1556 if (codegen.expValues) {
1557 operands.push_back(codegen.expCount);
1558 codegen.expCount = cntInput;
1559 }
1560 if (!operands.empty())
1561 builder.create<scf::YieldOp>(op.getLoc(), operands);
1562 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1563 }
1564
1565 //===----------------------------------------------------------------------===//
1566 // Sparse compiler synthesis methods (loop sequence).
1567 //===----------------------------------------------------------------------===//
1568
1569 /// Starts a loop sequence at the given level. Returns true if
1570 /// the universal loop index must be maintained at this level.
1571 static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
1572 linalg::GenericOp op, std::vector<unsigned> &topSort,
1573 unsigned exp, unsigned at, unsigned idx, unsigned ldx,
1574 unsigned lts) {
1575 assert(codegen.curVecLength == 1);
1576 assert(!codegen.loops[idx]);
1577 // Emit invariants at this loop sequence level.
1578 genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/true);
1579 // Emit access pattern expansion for sparse tensor output.
1580 genExpansion(merger, codegen, builder, op, at, /*atStart=*/true);
1581 // Emit further initialization at this loop sequence level.
1582 unsigned l0 = merger.set(lts)[0];
1583 bool needsUniv =
1584 genInit(merger, codegen, builder, op, topSort, at, merger.lat(l0).bits);
1585 // Maintain the universal index only if it is actually
1586 // consumed by a subsequent lattice point.
1587 if (needsUniv) {
1588 unsigned lsize = merger.set(lts).size();
1589 for (unsigned i = 1; i < lsize; i++) {
1590 unsigned li = merger.set(lts)[i];
1591 if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
1592 return true;
1593 }
1594 }
1595 return false;
1596 }
1597
1598 /// Starts a single loop in the current sequence.
1599 static Operation *startLoop(Merger &merger, CodeGen &codegen,
1600 OpBuilder &builder, linalg::GenericOp op,
1601 std::vector<unsigned> &topSort, unsigned at,
1602 unsigned li, bool needsUniv) {
1603 assert(codegen.curVecLength == 1);
1604 // Emit the for/while-loop control.
1605 Operation *loop = genLoop(merger, codegen, builder, op, topSort, at,
1606 needsUniv, merger.lat(li).simple);
1607 // Emit the locals for this loop.
1608 genLocals(merger, codegen, builder, op, topSort, at, needsUniv,
1609 merger.lat(li).bits);
1610 return loop;
1611 }
1612
1613 /// Ends a single loop in the current sequence. Returns the new value for needsUniv.
1614 static bool endLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder,
1615 linalg::GenericOp op, Operation *loop, unsigned idx,
1616 unsigned li, bool needsUniv) {
1617 codegen.curVecLength = 1;
1618 // End a while-loop.
1619 if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1620 genWhileInduction(merger, codegen, builder, op, idx, needsUniv,
1621 merger.lat(li).bits, whileOp);
1622 return needsUniv;
1623 }
1624 // End a for-loop.
1625 genForInduction(merger, codegen, builder, op, loop);
1626 return false;
1627 }
1628
1629 /// Ends a loop sequence at the given level.
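/// Besides clearing the loop index, this folds a reduction that was kept in
/// SIMD form back into a scalar value, conceptually a horizontal reduction
/// such as (illustrative only; the exact assembly form depends on the MLIR
/// version):
///
///   %red = vector.reduction <add>, %vred : vector<8xf64> into f64
///
/// and it unwinds the invariant and access pattern expansion bookkeeping set
/// up by startLoopSeq.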
1630 static void endLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
1631 linalg::GenericOp op, unsigned exp, unsigned at,
1632 unsigned idx, unsigned ldx) {
1633 assert(codegen.curVecLength == 1);
1634 codegen.loops[idx] = Value();
1635 // Bring a pending reduction back from SIMD form when sequence ends.
1636 if (codegen.redVal)
1637 if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
1638 updateReduc(merger, codegen,
1639 genVectorReducEnd(codegen, builder, op.getLoc(), vtp));
1640 // Unmark bookkeeping of invariants and loop index.
1641 genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/false);
1642 // Finalize access pattern expansion for sparse tensor output.
1643 genExpansion(merger, codegen, builder, op, at, /*atStart=*/false);
1644 }
1645
1646 /// Recursively generates code while computing iteration lattices in order
1647 /// to manage the complexity of implementing co-iteration over unions
1648 /// and intersections of sparse iteration spaces.
1649 static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
1650 linalg::GenericOp op, std::vector<unsigned> &topSort,
1651 unsigned exp, unsigned at) {
1652 // At each leaf, assign remaining tensor (sub)expression to output tensor.
1653 if (at == topSort.size()) {
1654 unsigned ldx = topSort[at - 1];
1655 Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx);
1656 genTensorStore(merger, codegen, rewriter, op, exp, rhs);
1657 return;
1658 }
1659
1660 // Construct iteration lattices for the current loop index, with L0 at top.
1661 unsigned idx = topSort[at];
1662 unsigned ldx = at == 0 ? -1u : topSort[at - 1];
1663 unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
1664
1665 // Start a loop sequence.
1666 bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
1667 idx, ldx, lts);
1668
1669 // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1670 unsigned lsize = merger.set(lts).size();
1671 for (unsigned i = 0; i < lsize; i++) {
1672 // Start a loop.
1673 unsigned li = merger.set(lts)[i];
1674 Operation *loop =
1675 startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);
1676
1677 // Visit all lattice points with Li >= Lj to generate the
1678 // loop-body, possibly with if statements for co-iteration.
1679 Value redInput = codegen.redVal;
1680 Value cntInput = codegen.expCount;
1681 bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
1682 for (unsigned j = 0; j < lsize; j++) {
1683 unsigned lj = merger.set(lts)[j];
1684 unsigned ej = merger.lat(lj).exp;
1685 if (li == lj || merger.latGT(li, lj)) {
1686 // Recurse into body of each branch.
1687 if (isWhile) {
1688 scf::IfOp ifOp =
1689 genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
1690 genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1691 endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
1692 } else {
1693 genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1694 }
1695 }
1696 }
1697
1698 // End a loop.
1699 needsUniv =
1700 endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
1701 }
1702
1703 // End a loop sequence.
1704 endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
1705 }
1706
1707 /// Converts the result computed by the sparse kernel into the required form.
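/// An annotated (sparse) output is rematerialized from its in-place storage;
/// a non-annotated output is simply reloaded from the bufferized values.
/// Illustrative forms only (placeholder names and types):
///
///   %0 = sparse_tensor.load %argx hasInserts : tensor<?x?xf64, #SparseMatrix>
///   %1 = bufferization.to_tensor %buf : memref<?x?xf64>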
1708 static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
1709 linalg::GenericOp op) {
1710 OpOperand *lhs = op.getOutputOperand(0);
1711 Type resType = lhs->get().getType();
1712 if (getSparseTensorEncoding(resType)) {
1713 // The sparse tensor rematerializes from the original sparse tensor's
1714 // underlying sparse storage format.
1715 rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
1716 codegen.sparseOut == lhs);
1717 } else {
1718 // To rematerialize a non-annotated tensor, simply load it
1719 // from the bufferized value.
1720 Value val = codegen.buffers.back(); // value array
1721 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1722 }
1723 }
1724
1725 //===----------------------------------------------------------------------===//
1726 // Sparse compiler rewriting methods.
1727 //===----------------------------------------------------------------------===//
1728
1729 namespace {
1730
1731 /// Sparse rewriting rule for a generic Linalg operation.
1732 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1733 public:
1734 GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1735 : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1736
1737 LogicalResult matchAndRewrite(linalg::GenericOp op,
1738 PatternRewriter &rewriter) const override {
1739 // Detects sparse annotations and translates the per-dimension sparsity
1740 // information for all tensors to loop indices in the kernel.
1741 assert(op.getNumOutputs() == 1);
1742 unsigned numTensors = op.getNumInputsAndOutputs();
1743 unsigned numLoops = op.iterator_types().getValue().size();
1744 Merger merger(numTensors, numLoops);
1745 if (!findSparseAnnotations(merger, op))
1746 return failure();
1747
1748 // Computes a topologically sorted iteration graph to ensure tensors
1749 // are visited in natural index order. Gradually relaxes the considered
1750 // constraints until an acyclic iteration graph results, such that sparse
1751 // code generation can proceed. As a last resort, an attempt is made
1752 // to resolve cycles by inserting a conversion.
1753 std::vector<unsigned> topSort;
1754 if (!computeIterationGraph(merger, op, topSort, SortMask::kIncludeAll) &&
1755 !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
1756 !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
1757 !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly)) {
1758 return resolveCycle(merger, rewriter, op);
1759 }
1760
1761 // Builds the tensor expression for the Linalg operation in SSA form.
1762 Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
1763 if (!optExp.hasValue())
1764 return failure();
1765 unsigned exp = optExp.getValue();
1766
1767 // Rejects an inadmissible tensor expression.
1768 OpOperand *sparseOut = nullptr;
1769 unsigned outerParNest = 0;
1770 if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
1771 outerParNest))
1772 return failure();
1773
1774 // Recursively generates code.
1775 merger.setHasSparseOut(sparseOut != nullptr);
1776 CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
1777 genBuffers(merger, codegen, rewriter, op);
1778 genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
1779 genResult(merger, codegen, rewriter, op);
1780 return success();
1781 }
1782
1783 private:
1784 // Last-resort cycle resolution.
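// Conceptually (illustrative only, with hypothetical encoding aliases), a
// cycle caused by a sparse input %a is broken by converting it into an
// encoding whose dimension ordering follows the computed topological order,
// and releasing the temporary afterwards:
//
//   %tmp = sparse_tensor.convert %a : tensor<?x?xf64, #CSR> to tensor<?x?xf64, #CSC>
//   ...                                // the rewritten kernel now reads %tmp
//   sparse_tensor.release %tmp : tensor<?x?xf64, #CSC>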
1785 LogicalResult resolveCycle(Merger &merger, PatternRewriter &rewriter,
1786 linalg::GenericOp op) const {
1787 // Compute topological sort while leaving out every
1788 // sparse input tensor in succession until an acyclic
1789 // iteration graph results.
1790 std::vector<unsigned> topSort;
1791 for (OpOperand *t : op.getInputOperands()) {
1792 unsigned tensor = t->getOperandNumber();
1793 Value tval = t->get();
1794 auto srcEnc = getSparseTensorEncoding(tval.getType());
1795 if (!srcEnc ||
1796 !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly, t))
1797 continue;
1798 // Found an input tensor that resolves the cycle by inserting a
1799 // conversion into a sparse tensor that adheres to the iteration
1800 // graph order. Also releases the temporary sparse tensor.
1801 //
1802 // TODO: investigate fusing the conversion with computation,
1803 // especially if it is a direct yield!
1804 //
1805 auto srcTp = tval.getType().cast<RankedTensorType>();
1806 auto dstEnc = SparseTensorEncodingAttr::get(
1807 op->getContext(), srcEnc.getDimLevelType(),
1808 permute(getContext(), op.getTiedIndexingMap(t), topSort), // new order
1809 srcEnc.getPointerBitWidth(), srcEnc.getIndexBitWidth());
1810 auto dstTp = RankedTensorType::get(srcTp.getShape(),
1811 srcTp.getElementType(), dstEnc);
1812 auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
1813 op->setOperand(tensor, convert);
1814 rewriter.setInsertionPointAfter(op);
1815 rewriter.create<ReleaseOp>(tval.getLoc(), convert);
1816 return success();
1817 }
1818 // Cannot be resolved with a single conversion.
1819 // TODO: convert more than one?
1820 return failure();
1821 }
1822
1823 /// Options to control sparse code generation.
1824 SparsificationOptions options;
1825 };
1826
1827 } // namespace
1828
1829 /// Populates the given patterns list with rewriting rules required for
1830 /// the sparsification of linear algebra operations.
1831 void mlir::populateSparsificationPatterns(
1832 RewritePatternSet &patterns, const SparsificationOptions &options) {
1833 patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1834 }
1835
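// A minimal usage sketch (hypothetical pass body, not part of this file):
// populate the sparsification patterns and apply them greedily.
//
//   void runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populateSparsificationPatterns(patterns, SparsificationOptions());
//     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//   }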