1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements converting sparse tensor types to actual sparse code. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CodegenUtils.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 18 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 19 #include "mlir/Dialect/Func/IR/FuncOps.h" 20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 21 #include "mlir/Dialect/Linalg/IR/Linalg.h" 22 #include "mlir/Dialect/Linalg/Utils/Utils.h" 23 #include "mlir/Dialect/MemRef/IR/MemRef.h" 24 #include "mlir/Dialect/SCF/IR/SCF.h" 25 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 26 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 28 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 29 #include "mlir/Dialect/Tensor/IR/Tensor.h" 30 #include "mlir/Dialect/Vector/IR/VectorOps.h" 31 #include "mlir/IR/Matchers.h" 32 #include "mlir/IR/TensorEncoding.h" 33 #include "llvm/ADT/SmallBitVector.h" 34 35 using namespace mlir; 36 using namespace mlir::sparse_tensor; 37 38 //===----------------------------------------------------------------------===// 39 // Declarations of data structures. 40 //===----------------------------------------------------------------------===// 41 42 namespace { 43 44 // Iteration graph sorting. 45 enum SortMask { 46 kSparseOnly = 0x0, 47 kIncludeDense = 0x1, 48 kIncludeUndef = 0x2, 49 kIncludeAll = 0x3 50 }; 51 52 // Reduction kinds. 53 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor }; 54 55 // Code generation. 56 struct CodeGen { 57 CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops, 58 OpOperand *op, unsigned nest) 59 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 60 pointers(numTensors, std::vector<Value>(numLoops)), 61 indices(numTensors, std::vector<Value>(numLoops)), 62 highs(numTensors, std::vector<Value>(numLoops)), 63 pidxs(numTensors, std::vector<Value>(numLoops)), 64 idxs(numTensors, std::vector<Value>(numLoops)), redVal(), sparseOut(op), 65 outerParNest(nest), lexIdx(), lexVal(), expValues(), expFilled(), 66 expAdded(), expCount(), curVecMask() {} 67 /// Sparsification options. 68 SparsificationOptions options; 69 /// Universal dense indices and upper bounds (by index). The loops array 70 /// is updated with the value of the universal dense index in the current 71 /// loop. The sizes array is set once with the inferred dimension sizes. 72 std::vector<Value> loops; 73 std::vector<Value> sizes; 74 /// Buffers for storing dense and sparse numerical values (by tensor). 75 /// This array is set once during bufferization of all tensors. 76 std::vector<Value> buffers; 77 /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 78 /// This array is set once during bufferization of all sparse tensors. 79 std::vector<std::vector<Value>> pointers; 80 std::vector<std::vector<Value>> indices; 81 /// Sparse iteration information (by tensor and index). 
These arrays 82 /// are updated to remain current within the current loop. 83 std::vector<std::vector<Value>> highs; 84 std::vector<std::vector<Value>> pidxs; 85 std::vector<std::vector<Value>> idxs; 86 /// Current reduction, updated during code generation. When indices of a 87 /// reduction are exhausted, all inner loops can use a scalarized reduction. 88 unsigned redExp = -1u; 89 Value redVal; 90 Reduction redKind = kNoReduc; 91 // Sparse tensor as output. Implemented either through direct injective 92 // insertion in lexicographic index order (where indices are updated 93 // in the temporary array `lexIdx`) or through access pattern expansion 94 // in the innermost loop nest (`expValues` through `expCount`). 95 OpOperand *sparseOut; 96 unsigned outerParNest; 97 Value lexIdx; 98 Value lexVal; 99 Value expValues; 100 Value expFilled; 101 Value expAdded; 102 Value expCount; 103 // Current vector length and mask. 104 unsigned curVecLength = 1; 105 Value curVecMask; 106 }; 107 108 } // namespace 109 110 //===----------------------------------------------------------------------===// 111 // Sparse compiler analysis methods. 112 //===----------------------------------------------------------------------===// 113 114 /// Helper method to construct a permuted dimension ordering 115 /// that adheres to the given topological sort. 116 static AffineMap permute(MLIRContext *context, AffineMap m, 117 std::vector<unsigned> &topSort) { 118 unsigned sz = topSort.size(); 119 assert(m.getNumResults() == sz && "TopoSort/AffineMap size mismatch"); 120 // Construct the inverse of `m`; to avoid the asymptotic complexity 121 // of calling `m.getPermutedPosition` repeatedly. 122 SmallVector<unsigned, 4> inv(sz); 123 for (unsigned i = 0; i < sz; i++) 124 inv[i] = m.getDimPosition(i); 125 // Construct the permutation. 126 SmallVector<unsigned, 4> perm(sz); 127 for (unsigned i = 0; i < sz; i++) 128 perm[i] = inv[topSort[i]]; 129 return AffineMap::getPermutationMap(perm, context); 130 } 131 132 /// Helper method to apply dimension ordering permutation. 133 static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) { 134 if (enc) { 135 auto order = enc.getDimOrdering(); 136 if (order) { 137 assert(order.isPermutation()); 138 return order.getDimPosition(d); 139 } 140 } 141 return d; 142 } 143 144 /// Helper method to translate dim level type to internal representation. 145 static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) { 146 if (enc) { 147 SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 148 if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 149 return Dim::kSparse; 150 if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 151 return Dim::kSingle; 152 } 153 return Dim::kDense; 154 } 155 156 /// Helper method to inspect affine expressions. Rejects cases where the 157 /// same index is used more than once. Also rejects affine expressions 158 /// that are not a direct index for annotated tensors. 
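/// For example, a subscript a(i,j) is accepted for both dense and annotated
/// (sparse) tensors, whereas a repeated index as in a(i,i), or a compound
/// subscript such as a(i+j) on an annotated tensor, is rejected.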
159 // TODO: accept more affine cases for sparse tensors 160 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim, 161 bool isDense) { 162 switch (a.getKind()) { 163 case AffineExprKind::DimId: { 164 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 165 if (!merger.isDim(tensor, idx, Dim::kUndef)) 166 return false; // used more than once 167 merger.setDim(tensor, idx, dim); 168 return true; 169 } 170 case AffineExprKind::Add: 171 case AffineExprKind::Mul: { 172 if (!isDense) 173 return false; 174 auto binOp = a.cast<AffineBinaryOpExpr>(); 175 return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) && 176 findAffine(merger, tensor, binOp.getRHS(), dim, isDense); 177 } 178 case AffineExprKind::Constant: 179 return isDense; 180 default: 181 return false; 182 } 183 } 184 185 /// Helper method to inspect sparse encodings in the tensor types. 186 /// Fills the per-dimension sparsity information for all tensors. 187 /// Returns true if the sparse annotations and affine subscript 188 /// expressions of all tensors are admissable. Returns false if 189 /// no annotations are found or inadmissable constructs occur. 190 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 191 bool annotated = false; 192 for (OpOperand *t : op.getInputAndOutputOperands()) { 193 auto map = op.getTiedIndexingMap(t); 194 auto enc = getSparseTensorEncoding(t->get().getType()); 195 if (enc) 196 annotated = true; 197 assert(map.getNumResults() == op.getRank(t)); 198 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 199 unsigned tensor = t->getOperandNumber(); 200 AffineExpr a = map.getResult(perm(enc, d)); 201 if (!findAffine(merger, tensor, a, toDim(enc, d), !enc)) 202 return false; // inadmissable affine expression 203 } 204 } 205 return annotated; 206 } 207 208 /// A DFS helper to compute a topological sort. Note that recursion is 209 /// bounded by the number of implicit loops, which is always small. 210 /// Returns false when a cycle is detected. 211 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 212 std::vector<unsigned> &topSort, 213 std::vector<std::vector<bool>> &adjM) { 214 if (visit[i] != 0) 215 return visit[i] != 1; // 1 denotes cycle! 216 visit[i] = 1; 217 for (unsigned j = 0, e = visit.size(); j < e; j++) 218 if (adjM[i][j]) 219 if (!topSortDFS(j, visit, topSort, adjM)) 220 return false; 221 visit[i] = 2; 222 topSort.push_back(i); 223 return true; 224 } 225 226 /// Helper method to add all constraints from the indices in one affine 227 /// expression before all indices in the other affine expression. For 228 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3. 229 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM, 230 AffineExpr a, AffineExpr b, unsigned fidx) { 231 switch (a.getKind()) { 232 case AffineExprKind::DimId: { 233 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 234 if (b) 235 addAffineOrderings(adjM, b, AffineExpr(), idx); 236 else 237 adjM[fidx][idx] = true; 238 break; 239 } 240 case AffineExprKind::Add: 241 case AffineExprKind::Mul: { 242 auto binOp = a.cast<AffineBinaryOpExpr>(); 243 addAffineOrderings(adjM, binOp.getLHS(), b, fidx); 244 addAffineOrderings(adjM, binOp.getRHS(), b, fidx); 245 break; 246 } 247 default: 248 break; 249 } 250 } 251 252 /// Computes a topologically sorted iteration graph for the linalg operation. 253 /// Ensures all tensors are visited in natural index order. 
/// This is essential for sparse storage formats since these only support
/// access along fixed dimensions. Even for dense storage formats, however,
/// the natural index order yields innermost unit-stride access with better
/// spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort, unsigned mask,
                                  OpOperand *skip = nullptr) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    // Skip tensor during cycle resolution.
    if (t == skip)
      continue;
    // Get map and encoding.
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if tensor has an in-place annotation.
static bool isInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<func::FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(),
              bufferization::BufferizableOpInterface::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Returns true if tensor materializes uninitialized into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<bufferization::AllocTensorOp>();
}

/// Returns true when the tensor expression is admissable for codegen.
/// Since all sparse input tensors are admissable, we just need to check
/// whether the out tensor in the tensor expression codegen is admissable.
/// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
/// nesting depth when a "truly dynamic" sparse tensor output occurs.
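/// Informally, admissable situations include a dense output, an all-dense
/// annotated "sparse" output, a "simply dynamic" in-place sparse output whose
/// nonzero structure is preserved, and a materializing sparse output with
/// sufficiently deep outer parallel nesting.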
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort, unsigned exp,
                                  OpOperand **sparseOut,
                                  unsigned &outerParNest) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissable since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissable since insertions cannot occur.
  bool allDense = true;
  auto iteratorTypes = op.iterator_types().getValue();
  unsigned numLoops = iteratorTypes.size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissable without special codegen, provided
  // the tensor's underlying sparse storage scheme can be modified in-place.
  if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get()))
    return true;
  // Accept "truly dynamic" if the output tensor materializes uninitialized
  // into the computation and insertions occur in lexicographic index order.
  if (isMaterializing(lhs->get())) {
    unsigned nest = 0;
    for (unsigned i = 0; i < numLoops; i++) {
      if (isReductionIterator(iteratorTypes[topSort[i]]))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissable dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissable 1-d expansion in innermost dimension.
    if (nest >= op.getRank(lhs) - 1) {
      *sparseOut = lhs;
      outerParNest = nest;
      return true;
    }
  }
  return false;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (reductions).
//===----------------------------------------------------------------------===//

/// Maps reduction kind to vector::CombiningKind.
static vector::CombiningKind getCombiningKind(Reduction kind) {
  switch (kind) {
  case kNoReduc:
    break;
  case kSum:
    return vector::CombiningKind::ADD;
  case kProduct:
    return vector::CombiningKind::MUL;
  case kAnd:
    return vector::CombiningKind::AND;
  case kOr:
    return vector::CombiningKind::OR;
  case kXor:
    return vector::CombiningKind::XOR;
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps operation to reduction.
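/// For example, additions and subtractions map to a kSum reduction, while
/// the integer bitwise operators map to kAnd, kOr, and kXor, respectively.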
static Reduction getReduction(Kind kind) {
  switch (kind) {
  case Kind::kAddF:
  case Kind::kAddC:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubC:
  case Kind::kSubI:
    return kSum;
  case Kind::kMulF:
  case Kind::kMulC:
  case Kind::kMulI:
    return kProduct;
  case Kind::kAndI:
    return kAnd;
  case Kind::kOrI:
    return kOr;
  case Kind::kXorI:
    return kXor;
  default:
    llvm_unreachable("unexpected reduction operator");
  }
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
static Value genVectorReducInit(CodeGen &codegen, OpBuilder &builder,
                                Location loc, VectorType vtp) {
  Value r = codegen.redVal;
  switch (codegen.redKind) {
  case kNoReduc:
    break;
  case kSum:
  case kXor:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return builder.create<vector::InsertElementOp>(
        loc, r, constantZero(builder, loc, vtp),
        constantIndex(builder, loc, 0));
  case kProduct:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return builder.create<vector::InsertElementOp>(
        loc, r, constantOne(builder, loc, vtp), constantIndex(builder, loc, 0));
  case kAnd:
  case kOr:
    // Initialize reduction vector to: | r | .. | r | r |
    return builder.create<vector::BroadcastOp>(loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
}

/// Generates final value for a vector reduction.
static Value genVectorReducEnd(CodeGen &codegen, OpBuilder &builder,
                               Location loc, VectorType vtp) {
  vector::CombiningKind kind = getCombiningKind(codegen.redKind);
  return builder.create<vector::ReductionOp>(loc, kind, codegen.redVal);
}

/// Updates scalarized reduction value.
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, OpBuilder &builder,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Value tensor = lhs->get();
  bool isInit = op.isInitTensor(lhs);
  // An output tensor that is in-place can simply materialize from the buffer
  // of the tensor that appears in the outs() clause. For updates, this has
  // the advantage that only the nonzero values are involved in the
  // computation, keeping the operation O(nnz).
  // In all other cases, we are forced to zero
  // out the buffer to enforce the assumption above, which may negatively
  // impact running complexity (viz. O(n^2 + nnz) vs. O(nnz) for matrices).
  // TODO: use better analysis to avoid zeroing out the buffer?
  if (isInPlace(tensor)) {
    Value init =
        builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
    if (!isInit) {
      Value zero = constantZero(builder, loc, denseTp.getElementType());
      builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{init});
    }
    return init;
  }
  // By default, a new buffer is allocated which is either set to zero (when
  // no updates occur or the tensor materializes into this computation) or
  // initialized to the value of the tensor defined in the outs() clause.
  // This is always correct (since it enforces all assumptions above) but
  // may negatively impact running complexity as explained above.
  Value alloc = builder.create<memref::AllocOp>(loc, denseTp, args);
  if (!isInit || isMaterializing(tensor)) {
    Value zero = constantZero(builder, loc, denseTp.getElementType());
    builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{alloc});
  } else {
    Value init =
        builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
    builder.create<memref::CopyOp>(loc, init, alloc);
  }
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                       linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp =
            MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
        auto indTp =
            MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
        Value dim = constantIndex(builder, loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            builder.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(builder, loc, t->get(), p);
      if (ShapedType::isDynamic(shape[p]))
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization.
Dense inputs materialize 567 // from the input tensors. Dense outputs need special handling. 568 // Sparse inputs use sparse primitives to obtain the values. 569 // We also accept in-place all-dense annotated "sparse" outputs. 570 Type elementType = getElementTypeOrSelf(t->get().getType()); 571 if (!enc) { 572 // Non-annotated dense tensors. 573 auto denseTp = MemRefType::get(shape, elementType); 574 if (tensor < op.getNumInputs()) 575 codegen.buffers[tensor] = 576 builder.create<bufferization::ToMemrefOp>(loc, denseTp, t->get()); 577 else 578 codegen.buffers[tensor] = 579 genOutputBuffer(codegen, builder, op, denseTp, args); 580 } else if (t == codegen.sparseOut) { 581 // True sparse output needs a lexIdx array. 582 Value rank = constantIndex(builder, loc, op.getRank(t)); 583 auto dynShape = {ShapedType::kDynamicSize}; 584 auto memTp = MemRefType::get(dynShape, builder.getIndexType()); 585 codegen.lexIdx = builder.create<memref::AllocaOp>(loc, memTp, rank); 586 codegen.lexVal = builder.create<memref::AllocaOp>( 587 loc, MemRefType::get({}, elementType)); 588 } else { 589 // Annotated sparse tensors. 590 auto dynShape = {ShapedType::kDynamicSize}; 591 auto sparseTp = MemRefType::get(dynShape, elementType); 592 codegen.buffers[tensor] = 593 builder.create<ToValuesOp>(loc, sparseTp, t->get()); 594 } 595 } 596 } 597 598 /// Constructs vector type. 599 static VectorType vectorType(CodeGen &codegen, Type etp) { 600 unsigned numScalableDims = codegen.options.enableVLAVectorization; 601 return VectorType::get(codegen.curVecLength, etp, numScalableDims); 602 } 603 604 /// Constructs vector type from pointer. 605 static VectorType vectorType(CodeGen &codegen, Value ptr) { 606 return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType()); 607 } 608 609 /// Constructs vector iteration mask. 610 static Value genVectorMask(CodeGen &codegen, OpBuilder &builder, Value iv, 611 Value lo, Value hi, Value step) { 612 Location loc = iv.getLoc(); 613 VectorType mtp = vectorType(codegen, builder.getI1Type()); 614 // Special case if the vector length evenly divides the trip count (for 615 // example, "for i = 0, 128, 16"). A constant all-true mask is generated 616 // so that all subsequent masked memory operations are immediately folded 617 // into unconditional memory operations. 618 IntegerAttr loInt, hiInt, stepInt; 619 if (matchPattern(lo, m_Constant(&loInt)) && 620 matchPattern(hi, m_Constant(&hiInt)) && 621 matchPattern(step, m_Constant(&stepInt))) { 622 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) 623 return builder.create<vector::BroadcastOp>( 624 loc, mtp, constantI1(builder, loc, true)); 625 } 626 // Otherwise, generate a vector mask that avoids overrunning the upperbound 627 // during vector execution. Here we rely on subsequent loop optimizations to 628 // avoid executing the mask in all iterations, for example, by splitting the 629 // loop into an unconditional vector loop and a scalar cleanup loop. 630 auto minMap = AffineMap::get( 631 /*dimCount=*/2, /*symbolCount=*/1, 632 {builder.getAffineSymbolExpr(0), 633 builder.getAffineDimExpr(0) - builder.getAffineDimExpr(1)}, 634 builder.getContext()); 635 Value end = 636 builder.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step}); 637 return builder.create<vector::CreateMaskOp>(loc, mtp, end); 638 } 639 640 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. 
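/// The former (indirect) form lowers to a gather through the index vector,
/// the latter (direct) form lowers to a masked load over [lo:hi].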
641 static Value genVectorLoad(CodeGen &codegen, OpBuilder &builder, Value ptr, 642 ArrayRef<Value> args) { 643 Location loc = ptr.getLoc(); 644 VectorType vtp = vectorType(codegen, ptr); 645 Value pass = constantZero(builder, loc, vtp); 646 if (args.back().getType().isa<VectorType>()) { 647 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 648 Value indexVec = args.back(); 649 scalarArgs.back() = constantIndex(builder, loc, 0); 650 return builder.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs, indexVec, 651 codegen.curVecMask, pass); 652 } 653 return builder.create<vector::MaskedLoadOp>(loc, vtp, ptr, args, 654 codegen.curVecMask, pass); 655 } 656 657 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. 658 static void genVectorStore(CodeGen &codegen, OpBuilder &builder, Value rhs, 659 Value ptr, ArrayRef<Value> args) { 660 Location loc = ptr.getLoc(); 661 if (args.back().getType().isa<VectorType>()) { 662 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 663 Value indexVec = args.back(); 664 scalarArgs.back() = constantIndex(builder, loc, 0); 665 builder.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, 666 codegen.curVecMask, rhs); 667 return; 668 } 669 builder.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask, 670 rhs); 671 } 672 673 /// Generates a vectorized invariant. Here we rely on subsequent loop 674 /// optimizations to hoist the invariant broadcast out of the vector loop. 675 static Value genVectorInvariantValue(CodeGen &codegen, OpBuilder &builder, 676 Value val) { 677 VectorType vtp = vectorType(codegen, val.getType()); 678 return builder.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 679 } 680 681 /// Generates an affine expression. 682 // 683 // TODO: generalize for sparse tensor subscripts 684 // 685 static Value genAffine(CodeGen &codegen, OpBuilder &builder, AffineExpr a, 686 Location loc) { 687 switch (a.getKind()) { 688 case AffineExprKind::DimId: { 689 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 690 return codegen.loops[idx]; // universal dense index 691 } 692 case AffineExprKind::Add: { 693 auto binOp = a.cast<AffineBinaryOpExpr>(); 694 return builder.create<arith::AddIOp>( 695 loc, genAffine(codegen, builder, binOp.getLHS(), loc), 696 genAffine(codegen, builder, binOp.getRHS(), loc)); 697 } 698 case AffineExprKind::Mul: { 699 auto binOp = a.cast<AffineBinaryOpExpr>(); 700 return builder.create<arith::MulIOp>( 701 loc, genAffine(codegen, builder, binOp.getLHS(), loc), 702 genAffine(codegen, builder, binOp.getRHS(), loc)); 703 } 704 case AffineExprKind::Constant: { 705 int64_t c = a.cast<AffineConstantExpr>().getValue(); 706 return constantIndex(builder, loc, c); 707 } 708 default: 709 llvm_unreachable("unexpected affine subscript"); 710 } 711 } 712 713 /// Generates index for load/store on sparse tensor. 714 static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) { 715 auto map = op.getTiedIndexingMap(t); 716 auto enc = getSparseTensorEncoding(t->get().getType()); 717 AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1)); 718 assert(a.getKind() == AffineExprKind::DimId); 719 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 720 return codegen.loops[idx]; 721 } 722 723 /// Generates subscript for load/store on a dense or sparse tensor. 
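/// For a sparse tensor, the single position index pidxs[tensor][idx] into the
/// linearized values buffer is used; for a dense tensor, one (affine) index
/// per dimension is generated.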
724 static Value genSubscript(CodeGen &codegen, OpBuilder &builder, 725 linalg::GenericOp op, OpOperand *t, 726 SmallVector<Value, 4> &args) { 727 unsigned tensor = t->getOperandNumber(); 728 auto map = op.getTiedIndexingMap(t); 729 auto enc = getSparseTensorEncoding(t->get().getType()); 730 unsigned rank = map.getNumResults(); 731 if (enc) { 732 // Note that currently, all sparse subscripts are simple. 733 // TODO: accept affine too? 734 AffineExpr a = map.getResult(perm(enc, rank - 1)); 735 assert(a.getKind() == AffineExprKind::DimId); 736 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 737 assert(codegen.pidxs[tensor][idx] != nullptr); 738 args.push_back(codegen.pidxs[tensor][idx]); // position index 739 } else { 740 for (unsigned d = 0; d < rank; d++) { 741 AffineExpr a = map.getResult(perm(enc, d)); 742 args.push_back(genAffine(codegen, builder, a, op.getLoc())); 743 } 744 } 745 return codegen.buffers[tensor]; 746 } 747 748 /// Generates insertion code to implement dynamic tensor load. 749 static Value genInsertionLoad(CodeGen &codegen, OpBuilder &builder, 750 linalg::GenericOp op, OpOperand *t) { 751 Location loc = op.getLoc(); 752 // Direct lexicographic index order, tensor loads as zero. 753 if (!codegen.expValues) { 754 Type tp = getElementTypeOrSelf(t->get().getType()); 755 return constantZero(builder, loc, tp); 756 } 757 // Load from expanded access pattern. 758 Value index = genIndex(codegen, op, t); 759 return builder.create<memref::LoadOp>(loc, codegen.expValues, index); 760 } 761 762 /// Generates insertion code to implement dynamic tensor store. 763 static void genInsertionStore(CodeGen &codegen, OpBuilder &builder, 764 linalg::GenericOp op, OpOperand *t, Value rhs) { 765 Location loc = op.getLoc(); 766 // Direct insertion in lexicographic index order. 767 if (!codegen.expValues) { 768 builder.create<memref::StoreOp>(loc, rhs, codegen.lexVal); 769 builder.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, codegen.lexVal); 770 return; 771 } 772 // Generates insertion code along expanded access pattern. 773 // if (!expFilled[i]) then 774 // expFilled[i] = true 775 // expAdded[inserts++] = i 776 // endif 777 // values[i] = rhs 778 Value index = genIndex(codegen, op, t); 779 Value fval = constantI1(builder, loc, false); 780 Value tval = constantI1(builder, loc, true); 781 // If statement. 782 Value filled = builder.create<memref::LoadOp>(loc, codegen.expFilled, index); 783 Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 784 filled, fval); 785 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond, 786 /*else=*/true); 787 // True branch. 788 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 789 builder.create<memref::StoreOp>(loc, tval, codegen.expFilled, index); 790 builder.create<memref::StoreOp>(loc, index, codegen.expAdded, 791 codegen.expCount); 792 Value one = constantIndex(builder, loc, 1); 793 Value add = builder.create<arith::AddIOp>(loc, codegen.expCount, one); 794 builder.create<scf::YieldOp>(loc, add); 795 // False branch. 796 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 797 builder.create<scf::YieldOp>(loc, codegen.expCount); 798 builder.setInsertionPointAfter(ifOp); 799 // Value assignment. 800 codegen.expCount = ifOp.getResult(0); 801 builder.create<memref::StoreOp>(loc, rhs, codegen.expValues, index); 802 } 803 804 /// Generates a load on a dense or sparse tensor. 
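/// Values hoisted as loop invariants are reused directly, loads on the sparse
/// output go through the insertion scheme, and all remaining loads become
/// scalar or vector memory loads on the bufferized tensor.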
static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, builder, val);
    return val;
  }
  // Load during insertion.
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  if (t == codegen.sparseOut)
    return genInsertionLoad(codegen, builder, op, t);
  // Actual load.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, builder, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, builder, ptr, args);
  return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned exp, Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = builder.create<arith::SelectOp>(loc, codegen.curVecMask, rhs,
                                            codegen.redVal);
    updateReduc(merger, codegen, rhs);
    return;
  }
  // Store during insertion.
  OpOperand *t = op.getOutputOperand(0);
  if (t == codegen.sparseOut) {
    if (!rhs) {
      // Only unary and binary are allowed to return uninitialized rhs
      // to indicate missing output.
      assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary);
    } else {
      genInsertionStore(codegen, builder, op, t, rhs);
    }
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, builder, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, builder, rhs, ptr, args);
  else
    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, OpBuilder &builder, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
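    // In short (vector case): widths below 32-bit are extended to i32, 32-bit
    // indices are extended to i64 unless enableSIMDIndex32 is set, and 64-bit
    // or index-typed values are used as is.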
878 Type etp = ptr.getType().cast<MemRefType>().getElementType(); 879 Value vload = genVectorLoad(codegen, builder, ptr, {s}); 880 if (!etp.isa<IndexType>()) { 881 if (etp.getIntOrFloatBitWidth() < 32) 882 vload = builder.create<arith::ExtUIOp>( 883 loc, vectorType(codegen, builder.getI32Type()), vload); 884 else if (etp.getIntOrFloatBitWidth() < 64 && 885 !codegen.options.enableSIMDIndex32) 886 vload = builder.create<arith::ExtUIOp>( 887 loc, vectorType(codegen, builder.getI64Type()), vload); 888 } 889 return vload; 890 } 891 // For the scalar case, we simply zero extend narrower indices into 64-bit 892 // values before casting to index without a performance penalty. Here too, 893 // however, indices that already are 64-bit, in theory, cannot express the 894 // full range as explained above. 895 Value load = builder.create<memref::LoadOp>(loc, ptr, s); 896 if (!load.getType().isa<IndexType>()) { 897 if (load.getType().getIntOrFloatBitWidth() < 64) 898 load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load); 899 load = 900 builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load); 901 } 902 return load; 903 } 904 905 /// Generates an invariant value. 906 static Value genInvariantValue(Merger &merger, CodeGen &codegen, 907 OpBuilder &builder, unsigned exp) { 908 Value val = merger.exp(exp).val; 909 if (codegen.curVecLength > 1) 910 return genVectorInvariantValue(codegen, builder, val); 911 return val; 912 } 913 914 /// Generates an address computation "sz * p + i". 915 static Value genAddress(CodeGen &codegen, OpBuilder &builder, Location loc, 916 Value size, Value p, Value i) { 917 Value mul = builder.create<arith::MulIOp>(loc, size, p); 918 if (auto vtp = i.getType().dyn_cast<VectorType>()) { 919 Value inv = 920 builder.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul); 921 mul = genVectorInvariantValue(codegen, builder, inv); 922 } 923 return builder.create<arith::AddIOp>(loc, mul, i); 924 } 925 926 /// Generates an index value. 927 static Value genIndexValue(CodeGen &codegen, OpBuilder &builder, unsigned idx, 928 unsigned ldx) { 929 Value ival = codegen.loops[idx]; 930 Type itype = ival.getType(); 931 // During vectorization, we either encounter: 932 // (1) indices already in vector form, as in ... = ind[lo:hi], good to go, or 933 // (2) single index, as in ... = i, must convert to [i, i+1, ...] for inner i. 934 unsigned vl = codegen.curVecLength; 935 if (vl > 1 && !itype.isa<VectorType>()) { 936 Location loc = ival.getLoc(); 937 VectorType vtp = vectorType(codegen, itype); 938 ival = builder.create<vector::BroadcastOp>(loc, vtp, ival); 939 if (idx == ldx) { 940 Value incr; 941 if (vtp.isScalable()) { 942 Type stepvty = vectorType(codegen, builder.getI64Type()); 943 Value stepv = builder.create<LLVM::StepVectorOp>(loc, stepvty); 944 incr = builder.create<arith::IndexCastOp>(loc, vtp, stepv); 945 } else { 946 SmallVector<APInt, 4> integers; 947 for (unsigned i = 0; i < vl; i++) 948 integers.push_back(APInt(/*width=*/64, i)); 949 auto values = DenseElementsAttr::get(vtp, integers); 950 incr = builder.create<arith::ConstantOp>(loc, vtp, values); 951 } 952 ival = builder.create<arith::AddIOp>(loc, ival, incr); 953 } 954 } 955 return ival; 956 } 957 958 /// Semi-ring branches are simply inlined by the sparse compiler. 
Prior 959 /// analysis has verified that all computations are "local" to the inlined 960 /// branch or otherwise invariantly defined outside the loop nest, with the 961 /// exception of index computations, which need to be relinked to actual 962 /// inlined cloned code. 963 static Value relinkBranch(CodeGen &codegen, RewriterBase &rewriter, 964 Block *block, Value e, unsigned ldx) { 965 if (Operation *def = e.getDefiningOp()) { 966 if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) 967 return genIndexValue(codegen, rewriter, indexOp.dim(), ldx); 968 if (def->getBlock() == block) { 969 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) 970 def->setOperand( 971 i, relinkBranch(codegen, rewriter, block, def->getOperand(i), ldx)); 972 } 973 } 974 return e; 975 } 976 977 /// Recursively generates tensor expression. 978 static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, 979 linalg::GenericOp op, unsigned exp, unsigned ldx) { 980 Location loc = op.getLoc(); 981 if (exp == -1u) 982 return Value(); 983 if (merger.exp(exp).kind == Kind::kTensor) 984 return genTensorLoad(merger, codegen, rewriter, op, exp); 985 if (merger.exp(exp).kind == Kind::kInvariant) 986 return genInvariantValue(merger, codegen, rewriter, exp); 987 if (merger.exp(exp).kind == Kind::kIndex) 988 return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx); 989 Value v0 = 990 genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx); 991 Value v1 = 992 genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx); 993 Value ee = merger.buildExp(rewriter, loc, exp, v0, v1); 994 if (ee && (merger.exp(exp).kind == Kind::kUnary || 995 merger.exp(exp).kind == Kind::kBinary || 996 merger.exp(exp).kind == Kind::kBinaryBranch)) 997 ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx); 998 return ee; 999 } 1000 1001 /// Determines if affine expression is invariant. 1002 static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a, 1003 unsigned ldx, bool &atLevel) { 1004 switch (a.getKind()) { 1005 case AffineExprKind::DimId: { 1006 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 1007 if (idx == ldx) 1008 atLevel = true; 1009 return codegen.loops[idx] != nullptr; // no longer in play? 1010 } 1011 case AffineExprKind::Add: 1012 case AffineExprKind::Mul: { 1013 auto binOp = a.cast<AffineBinaryOpExpr>(); 1014 return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) && 1015 isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel); 1016 } 1017 default: 1018 return true; 1019 } 1020 } 1021 1022 /// Hoists loop invariant tensor loads for which indices have been exhausted. 1023 static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1024 linalg::GenericOp op, unsigned exp, unsigned ldx, 1025 bool atStart, Kind last = Kind::kTensor) { 1026 if (exp == -1u) 1027 return; 1028 if (merger.exp(exp).kind == Kind::kTensor) { 1029 // Inspect tensor indices. 1030 bool atLevel = ldx == -1u; 1031 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 1032 auto map = op.getTiedIndexingMap(t); 1033 auto enc = getSparseTensorEncoding(t->get().getType()); 1034 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1035 AffineExpr a = map.getResult(perm(enc, d)); 1036 if (!isInvariantAffine(codegen, a, ldx, atLevel)) 1037 return; // still in play 1038 } 1039 // All exhausted at this level (atLevel denotes exactly at this level). 
1040 if (!atLevel) 1041 return; 1042 OpOperand *lhs = op.getOutputOperand(0); 1043 if (lhs == t) { 1044 // Start or end a scalarized reduction 1045 if (atStart) { 1046 Value load = genTensorLoad(merger, codegen, builder, op, exp); 1047 codegen.redKind = getReduction(last); 1048 codegen.redExp = exp; 1049 updateReduc(merger, codegen, load); 1050 } else { 1051 Value redVal = codegen.redVal; 1052 updateReduc(merger, codegen, Value()); 1053 codegen.redExp = -1u; 1054 codegen.redKind = kNoReduc; 1055 genTensorStore(merger, codegen, builder, op, exp, redVal); 1056 } 1057 } else { 1058 // Start or end loop invariant hoisting of a tensor load. 1059 merger.exp(exp).val = 1060 atStart ? genTensorLoad(merger, codegen, builder, op, exp) : Value(); 1061 } 1062 } else if (merger.exp(exp).kind != Kind::kInvariant && 1063 merger.exp(exp).kind != Kind::kIndex) { 1064 // Traverse into the binary operations. Note that we only hoist 1065 // tensor loads, since subsequent MLIR/LLVM passes know how to 1066 // deal with all other kinds of derived loop invariants. 1067 Kind last = merger.exp(exp).kind; 1068 unsigned e0 = merger.exp(exp).children.e0; 1069 unsigned e1 = merger.exp(exp).children.e1; 1070 genInvariants(merger, codegen, builder, op, e0, ldx, atStart, last); 1071 genInvariants(merger, codegen, builder, op, e1, ldx, atStart, last); 1072 } 1073 } 1074 1075 /// Generates an expanded access pattern in innermost dimension. 1076 static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1077 linalg::GenericOp op, unsigned at, bool atStart) { 1078 OpOperand *lhs = codegen.sparseOut; 1079 if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 || 1080 at != codegen.outerParNest) 1081 return; // not needed at this level 1082 // Generate start or end of an expanded access pattern. 1083 Value tensor = lhs->get(); 1084 Location loc = op.getLoc(); 1085 if (atStart) { 1086 auto dynShape = {ShapedType::kDynamicSize}; 1087 Type etp = tensor.getType().cast<ShapedType>().getElementType(); 1088 Type t1 = MemRefType::get(dynShape, etp); 1089 Type t2 = MemRefType::get(dynShape, builder.getI1Type()); 1090 Type t3 = MemRefType::get(dynShape, builder.getIndexType()); 1091 Type t4 = builder.getIndexType(); 1092 auto res = 1093 builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor); 1094 assert(res.getNumResults() == 4); 1095 assert(!codegen.expValues); 1096 codegen.expValues = res.getResult(0); 1097 codegen.expFilled = res.getResult(1); 1098 codegen.expAdded = res.getResult(2); 1099 codegen.expCount = res.getResult(3); 1100 } else { 1101 assert(codegen.expValues); 1102 builder.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues, 1103 codegen.expFilled, codegen.expAdded, 1104 codegen.expCount); 1105 codegen.expValues = codegen.expFilled = codegen.expAdded = 1106 codegen.expCount = Value(); 1107 } 1108 } 1109 1110 /// Generates initialization code for the subsequent loop sequence at 1111 /// current index level. Returns true if the loop sequence needs to 1112 /// maintain the universal index. 1113 static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1114 linalg::GenericOp op, std::vector<unsigned> &topSort, 1115 unsigned at, BitVector &inits) { 1116 bool needsUniv = false; 1117 Location loc = op.getLoc(); 1118 unsigned idx = topSort[at]; 1119 1120 // Initialize sparse positions. 
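  // For each sparse dimension that is entered here, the pointer buffer is
  // consulted at the enclosing position p0 and at p0 + 1 to obtain the
  // half-open range [lo, hi) of positions to iterate over.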
1121 for (unsigned b = 0, be = inits.size(); b < be; b++) { 1122 if (inits[b]) { 1123 unsigned tensor = merger.tensor(b); 1124 assert(idx == merger.index(b)); 1125 if (merger.isDim(b, Dim::kSparse)) { 1126 // Initialize sparse index. 1127 unsigned pat = at; 1128 for (; pat != 0; pat--) { 1129 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1130 break; 1131 } 1132 Value ptr = codegen.pointers[tensor][idx]; 1133 Value one = constantIndex(builder, loc, 1); 1134 Value p0 = (pat == 0) ? constantIndex(builder, loc, 0) 1135 : codegen.pidxs[tensor][topSort[pat - 1]]; 1136 codegen.pidxs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p0); 1137 Value p1 = builder.create<arith::AddIOp>(loc, p0, one); 1138 codegen.highs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p1); 1139 } else { 1140 // Dense index still in play. 1141 needsUniv = true; 1142 } 1143 } 1144 } 1145 1146 // Initialize the universal dense index. 1147 codegen.loops[idx] = constantIndex(builder, loc, 0); 1148 return needsUniv; 1149 } 1150 1151 /// Returns vectorization strategy. Any implicit inner loop in the Linalg 1152 /// operation is a candidate. Whether it is actually converted to SIMD code 1153 /// depends on the requested strategy. 1154 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isReduction, 1155 bool isSparse) { 1156 // Reject vectorization of sparse output, unless innermost is reduction. 1157 if (codegen.sparseOut && !isReduction) 1158 return false; 1159 // Inspect strategy. 1160 switch (codegen.options.vectorizationStrategy) { 1161 case SparseVectorizationStrategy::kNone: 1162 return false; 1163 case SparseVectorizationStrategy::kDenseInnerLoop: 1164 return isInner && !isSparse; 1165 case SparseVectorizationStrategy::kAnyStorageInnerLoop: 1166 return isInner; 1167 } 1168 llvm_unreachable("unexpected vectorization strategy"); 1169 } 1170 1171 /// Returns parallelization strategy. Any implicit loop in the Linalg operation 1172 /// that is marked "parallel" is a candidate. Whether it is actually converted 1173 /// to a parallel operation depends on the requested strategy. 1174 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, 1175 bool isSparse, bool isVector) { 1176 // Reject parallelization of sparse output. 1177 if (codegen.sparseOut) 1178 return false; 1179 // Inspect strategy. 1180 switch (codegen.options.parallelizationStrategy) { 1181 case SparseParallelizationStrategy::kNone: 1182 return false; 1183 case SparseParallelizationStrategy::kDenseOuterLoop: 1184 return isOuter && !isSparse && !isReduction && !isVector; 1185 case SparseParallelizationStrategy::kAnyStorageOuterLoop: 1186 return isOuter && !isReduction && !isVector; 1187 case SparseParallelizationStrategy::kDenseAnyLoop: 1188 return !isSparse && !isReduction && !isVector; 1189 case SparseParallelizationStrategy::kAnyStorageAnyLoop: 1190 return !isReduction && !isVector; 1191 } 1192 llvm_unreachable("unexpected parallelization strategy"); 1193 } 1194 1195 /// Checks unit stride for dense tensors. The iteration graph may have ignored 1196 /// dense access patterns in order to avoid cycles (sparse access patterns are 1197 /// always placed innermost), but that means dense access has become strided. 1198 /// This prevents effective vectorization. 
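/// For example, with innermost index i, a dense access a(i,j) is strided in
/// the innermost loop, and a compound innermost subscript such as a(j,i+k)
/// is also rejected, even though it is technically unit stride.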
1199 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 1200 unsigned idx) { 1201 for (OpOperand *t : op.getInputAndOutputOperands()) { 1202 if (!getSparseTensorEncoding(t->get().getType())) { 1203 auto map = op.getTiedIndexingMap(t); 1204 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1205 AffineExpr a = map.getResult(d); 1206 // Report non-unit stride if innermost index appears at an outer 1207 // dimension (true non-unit stride) or if the innermost index appears 1208 // in a compound subscript in the innermost dimension. Even if the 1209 // latter is unit stride, it does not play well with scatter/gather. 1210 // TODO: accept unit stride affine innermost like a[i,j+k+1]? 1211 if (a.isFunctionOfDim(idx) && 1212 ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) 1213 return false; 1214 } 1215 } 1216 } 1217 return true; 1218 } 1219 1220 /// Generates a for-loop on a single index. 1221 static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1222 linalg::GenericOp op, bool isOuter, bool isInner, 1223 unsigned idx, BitVector &indices) { 1224 unsigned fb = indices.find_first(); 1225 unsigned tensor = merger.tensor(fb); 1226 assert(idx == merger.index(fb)); 1227 auto iteratorTypes = op.iterator_types().getValue(); 1228 bool isReduction = isReductionIterator(iteratorTypes[idx]); 1229 bool isSparse = merger.isDim(fb, Dim::kSparse); 1230 bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) && 1231 denseUnitStrides(merger, op, idx); 1232 bool isParallel = 1233 isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 1234 1235 // Prepare vector length. 1236 if (isVector) 1237 codegen.curVecLength = codegen.options.vectorLength; 1238 1239 // Loop bounds and increment. 1240 Location loc = op.getLoc(); 1241 Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 1242 Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 1243 Value step = constantIndex(builder, loc, codegen.curVecLength); 1244 if (isVector && codegen.options.enableVLAVectorization) { 1245 Value vscale = builder.create<vector::VectorScaleOp>( 1246 loc, IndexType::get(builder.getContext())); 1247 step = builder.create<arith::MulIOp>(loc, vscale, step); 1248 } 1249 1250 // Emit a parallel loop. 1251 if (isParallel) { 1252 assert(!isVector); 1253 scf::ParallelOp parOp = builder.create<scf::ParallelOp>(loc, lo, hi, step); 1254 if (isSparse) 1255 codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 1256 else 1257 codegen.loops[idx] = parOp.getInductionVars()[0]; 1258 builder.setInsertionPointToStart(parOp.getBody()); 1259 return parOp; 1260 } 1261 1262 // Emit a sequential or vector loop. 1263 SmallVector<Value, 4> operands; 1264 if (codegen.redVal) { 1265 // In a vector loop, bring reduction into SIMD form, if not already. 1266 if (isVector && !codegen.redVal.getType().isa<VectorType>()) { 1267 VectorType vtp = vectorType(codegen, codegen.redVal.getType()); 1268 Value vred = genVectorReducInit(codegen, builder, loc, vtp); 1269 updateReduc(merger, codegen, vred); 1270 } 1271 operands.push_back(codegen.redVal); 1272 } 1273 if (codegen.expValues) 1274 operands.push_back(codegen.expCount); 1275 scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, operands); 1276 if (codegen.redVal) 1277 updateReduc(merger, codegen, forOp.getRegionIterArgs().front()); 1278 if (codegen.expValues) 1279 codegen.expCount = forOp.getRegionIterArgs().back(); 1280 // Assign induction variable to sparse or dense index. 
1281 Value iv = forOp.getInductionVar(); 1282 if (isSparse) 1283 codegen.pidxs[tensor][idx] = iv; 1284 else 1285 codegen.loops[idx] = iv; 1286 builder.setInsertionPointToStart(forOp.getBody()); 1287 // Share vector iteration mask between all subsequent loads/stores. 1288 if (isVector) 1289 codegen.curVecMask = genVectorMask(codegen, builder, iv, lo, hi, step); 1290 return forOp; 1291 } 1292 1293 /// Emit a while-loop for co-iteration over multiple indices. 1294 static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1295 linalg::GenericOp op, unsigned idx, bool needsUniv, 1296 BitVector &indices) { 1297 SmallVector<Type, 4> types; 1298 SmallVector<Value, 4> operands; 1299 // Construct the while-loop with a parameter for each index. 1300 Type indexType = builder.getIndexType(); 1301 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1302 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1303 unsigned tensor = merger.tensor(b); 1304 assert(idx == merger.index(b)); 1305 types.push_back(indexType); 1306 operands.push_back(codegen.pidxs[tensor][idx]); 1307 } 1308 } 1309 if (codegen.redVal) { 1310 types.push_back(codegen.redVal.getType()); 1311 operands.push_back(codegen.redVal); 1312 } 1313 if (codegen.expValues) { 1314 types.push_back(indexType); 1315 operands.push_back(codegen.expCount); 1316 } 1317 if (needsUniv) { 1318 types.push_back(indexType); 1319 operands.push_back(codegen.loops[idx]); 1320 } 1321 assert(types.size() == operands.size()); 1322 Location loc = op.getLoc(); 1323 scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands); 1324 1325 SmallVector<Location> locs(types.size(), loc); 1326 Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs); 1327 Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs); 1328 1329 // Build the "before" region, which effectively consists 1330 // of a conjunction of "i < upper" tests on all induction. 1331 builder.setInsertionPointToStart(&whileOp.getBefore().front()); 1332 Value cond; 1333 unsigned o = 0; 1334 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1335 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1336 unsigned tensor = merger.tensor(b); 1337 assert(idx == merger.index(b)); 1338 Value op1 = before->getArgument(o); 1339 Value op2 = codegen.highs[tensor][idx]; 1340 Value opc = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 1341 op1, op2); 1342 cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc; 1343 codegen.pidxs[tensor][idx] = after->getArgument(o++); 1344 } 1345 } 1346 if (codegen.redVal) 1347 updateReduc(merger, codegen, after->getArgument(o++)); 1348 if (codegen.expValues) 1349 codegen.expCount = after->getArgument(o++); 1350 if (needsUniv) 1351 codegen.loops[idx] = after->getArgument(o++); 1352 assert(o == operands.size()); 1353 builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); 1354 builder.setInsertionPointToStart(&whileOp.getAfter().front()); 1355 return whileOp; 1356 } 1357 1358 /// Generates a for-loop or a while-loop, depending on whether it implements 1359 /// singleton iteration or co-iteration over the given conjunction. 
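/// For example, x(i) = a(i) * b(i) with both a and b sparse in i co-iterates
/// the two tensors with a while-loop, whereas a single sparse or dense index
/// simply yields a for-loop.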

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                          linalg::GenericOp op, std::vector<unsigned> &topSort,
                          unsigned at, bool needsUniv, BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, builder, op, isOuter, isInner, idx, indices);
  }
  return genWhile(merger, codegen, builder, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                      linalg::GenericOp op, std::vector<unsigned> &topSort,
                      unsigned at, bool needsUniv, BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, builder, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp = builder.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, load, min);
          min = builder.create<arith::SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? constantIndex(builder, loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, builder, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }

  // Move the insertion indices in lexicographic index order. During access
  // pattern expansion, we can skip setting the innermost dimension.
  if (codegen.sparseOut && !codegen.expValues) {
    Value pos = constantIndex(builder, loc, at);
    builder.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
                                    pos);
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              OpBuilder &builder, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              BitVector &induction, scf::WhileOp whileOp) {
  Location loc = op.getLoc();
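
  // Illustrative note: if statements wrapped around the loop body (see
  // genIf/endIf below) let a pending reduction or expansion count escape
  // each scf.if through its results, roughly as in this sketch:
  //
  //   %r_out = scf.if %cond -> (f64) {
  //     ... then-branch computes %r_then ...
  //     scf.yield %r_then : f64
  //   } else {
  //     scf.yield %r_in : f64   // else yields the incoming value
  //   }
  //
  // The loop below walks these enclosing ifs, emits the missing else
  // yields, and rebinds redVal/expCount to the if results.
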
  // Finalize each else branch of all if statements.
  if (codegen.redVal || codegen.expValues) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               builder.getInsertionBlock()->getParentOp())) {
      unsigned y = 0;
      SmallVector<Value, 4> yields;
      if (codegen.redVal) {
        yields.push_back(codegen.redVal);
        updateReduc(merger, codegen, ifOp.getResult(y++));
      }
      if (codegen.expValues) {
        yields.push_back(codegen.expCount);
        codegen.expCount = ifOp->getResult(y++);
      }
      assert(y == yields.size());
      builder.create<scf::YieldOp>(loc, yields);
      builder.setInsertionPointAfter(ifOp);
    }
  }
  builder.setInsertionPointToEnd(&whileOp.getAfter().front());
  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = constantIndex(builder, loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                op1, op2);
      Value add = builder.create<arith::AddIOp>(loc, op3, one);
      operands.push_back(builder.create<arith::SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
    }
  }
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, whileOp->getResult(o++));
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = whileOp->getResult(o++);
  }
  if (needsUniv) {
    operands.push_back(
        builder.create<arith::AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = whileOp->getResult(o++);
  }
  assert(o == operands.size());
  builder.create<scf::YieldOp>(loc, operands);
  builder.setInsertionPointAfter(whileOp);
}

/// Generates the induction structure for a for-loop.
static void genForInduction(Merger &merger, CodeGen &codegen,
                            OpBuilder &builder, linalg::GenericOp op,
                            Operation *loop) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, loop->getResult(o++));
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = loop->getResult(o++);
  }
  assert(o == operands.size());
  if (o > 0)
    builder.create<scf::YieldOp>(loc, operands);
  builder.setInsertionPointAfter(loop);
}
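
// For illustration only: taken together, genFor and genForInduction turn a
// scalar sum reduction into a loop of roughly this form (element type and
// buffer names are placeholders):
//
//   %sum = scf.for %i = %lo to %hi step %c1 iter_args(%r = %r0) -> (f64) {
//     %v = memref.load %values[%i] : memref<?xf64>
//     %r1 = arith.addf %r, %v : f64
//     scf.yield %r1 : f64
//   }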

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                       linalg::GenericOp op, unsigned idx,
                       BitVector &conditions) {
  Location loc = op.getLoc();
  SmallVector<Type, 4> types;
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                               op1, op2);
      } else {
        clause = constantI1(builder, loc, true);
      }
      cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
    }
  }
  if (codegen.redVal)
    types.push_back(codegen.redVal.getType());
  if (codegen.expValues)
    types.push_back(builder.getIndexType());
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}

/// Generates the end of the true branch of an if-statement within a
/// while-loop.
static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                  linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
                  Value redInput, Value cntInput) {
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, redInput);
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = cntInput;
  }
  if (!operands.empty())
    builder.create<scf::YieldOp>(op.getLoc(), operands);
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at the given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                         linalg::GenericOp op, std::vector<unsigned> &topSort,
                         unsigned exp, unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  assert(!codegen.loops[idx]);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/true);
  // Emit access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, builder, op, at, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  bool needsUniv =
      genInit(merger, codegen, builder, op, topSort, at, merger.lat(l0).bits);
  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  if (needsUniv) {
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}
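
// For illustration only: when index i co-iterates over a union of two
// sparse operands A and B (as in elementwise addition), the optimized
// lattice has points {A and B, A only, B only}, and the synthesis below
// unfolds into consecutive loops, one per lattice point:
//
//   startLoopSeq             // invariants, expansion, initial positions
//   scf.while { ... }        // L0: both A and B may still have entries
//   scf.for   { ... }        // L1: leftover entries of A
//   scf.for   { ... }        // L2: leftover entries of B
//   endLoopSeq               // reset loop index, finalize pending reduction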

/// Starts a single loop in the current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            OpBuilder &builder, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, builder, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, builder, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in the current sequence. Returns the new value for
/// needsUniv.
static bool endLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    genWhileInduction(merger, codegen, builder, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp);
    return needsUniv;
  }
  // End a for-loop.
  genForInduction(merger, codegen, builder, op, loop);
  return false;
}

/// Ends a loop sequence at the given level.
static void endLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                       linalg::GenericOp op, unsigned exp, unsigned at,
                       unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  codegen.loops[idx] = Value();
  // Bring a pending reduction back from SIMD form when sequence ends.
  if (codegen.redVal)
    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
      updateReduc(merger, codegen,
                  genVectorReducEnd(codegen, builder, op.getLoc(), vtp));
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/false);
  // Finalize access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, builder, op, at, /*atStart=*/false);
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    unsigned ldx = topSort[at - 1];
    Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx);
    genTensorStore(merger, codegen, rewriter, op, exp, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);
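
    // Both genFor and genWhile rebind a pending reduction value and
    // expansion count to the loop-carried values of the new loop; the
    // snapshots taken below let endIf() restore these incoming values in
    // each else branch.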

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    Value redInput = codegen.redVal;
    Value cntInput = codegen.expCount;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
                      linalg::GenericOp op) {
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format.
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
                                        codegen.sparseOut == lhs);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = codegen.buffers.back(); // value array
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
  }
}

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();
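
    // For illustration only: a kernel of the following shape (SpMV, with the
    // matrix carrying a sparse encoding such as a hypothetical #CSR alias,
    // and #trait supplying the usual indexing maps and iterator types) is
    // matched here and lowered into the loop nests generated above:
    //
    //   %0 = linalg.generic #trait
    //     ins(%A, %b : tensor<?x?xf64, #CSR>, tensor<?xf64>)
    //     outs(%x : tensor<?xf64>) {
    //     ^bb0(%a: f64, %bj: f64, %xi: f64):
    //       %m = arith.mulf %a, %bj : f64
    //       %r = arith.addf %xi, %m : f64
    //       linalg.yield %r : f64
    //   } -> tensor<?xf64>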

    // Computes a topologically sorted iteration graph to ensure tensors
    // are visited in natural index order. Gradually relaxes the considered
    // constraints until an acyclic iteration graph results, such that sparse
    // code generation can proceed. As a last resort, an attempt is made
    // to resolve cycles by inserting a conversion.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, SortMask::kIncludeAll) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly)) {
      return resolveCycle(merger, rewriter, op);
    }

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
    if (!optExp.hasValue())
      return failure();
    unsigned exp = optExp.getValue();

    // Rejects an inadmissible tensor expression.
    OpOperand *sparseOut = nullptr;
    unsigned outerParNest = 0;
    if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
                               outerParNest))
      return failure();

    // Recursively generates code.
    merger.setHasSparseOut(sparseOut != nullptr);
    CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  // Last resort cycle resolution.
  LogicalResult resolveCycle(Merger &merger, PatternRewriter &rewriter,
                             linalg::GenericOp op) const {
    // Compute topological sort while leaving out every
    // sparse input tensor in succession until an acyclic
    // iteration graph results.
    std::vector<unsigned> topSort;
    for (OpOperand *t : op.getInputOperands()) {
      unsigned tensor = t->getOperandNumber();
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      if (!srcEnc ||
          !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly, t))
        continue;
      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order. Also releases the temporary sparse tensor.
      //
      // TODO: investigate fusing the conversion with computation,
      // especially if it is a direct yield!
      //
      auto srcTp = tval.getType().cast<RankedTensorType>();
      auto dstEnc = SparseTensorEncodingAttr::get(
          op->getContext(), srcEnc.getDimLevelType(),
          permute(getContext(), op.getTiedIndexingMap(t), topSort), // new order
          srcEnc.getPointerBitWidth(), srcEnc.getIndexBitWidth());
      auto dstTp = RankedTensorType::get(srcTp.getShape(),
                                         srcTp.getElementType(), dstEnc);
      auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
      op->setOperand(tensor, convert);
      rewriter.setInsertionPointAfter(op);
      rewriter.create<ReleaseOp>(tval.getLoc(), convert);
      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }

  /// Options to control sparse code generation.
  SparsificationOptions options;
};
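
// For illustration only: when resolveCycle succeeds, the rewrite has the
// effect sketched below, where #CSR/#CSC stand for hypothetical encoding
// aliases and the new dimension order is the one dictated by the iteration
// graph:
//
//   %t = sparse_tensor.convert %A
//          : tensor<?x?xf64, #CSR> to tensor<?x?xf64, #CSC>
//   ... the linalg.generic now reads %t instead of %A ...
//   sparse_tensor.release %t : tensor<?x?xf64, #CSC>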

/// Sparse rewriting rule for the expand shape operator.
struct ExpandShapeRewriter : public OpRewritePattern<tensor::ExpandShapeOp> {
public:
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.result().getType());
    auto encSrc = getSparseTensorEncoding(op.src().getType());
    // Since a pure dense expansion is very cheap (change of view), for
    // sparse2dense or dense2sparse, we can simply unfuse a sparse
    // conversion from the actual expansion operation itself.
    if (encDst && encSrc) {
      return failure(); // TODO: implement sparse2sparse
    } else if (encSrc) {
      RankedTensorType rtp = op.src().getType().cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.src());
      op->setOperand(0, convert);
      return success();
    } else if (encDst) {
      RankedTensorType rtp = op.result().getType().cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto reshape = rewriter.create<tensor::ExpandShapeOp>(
          loc, denseTp, op.src(), op.getReassociation());
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};

/// Sparse rewriting rule for the collapse shape operator.
struct CollapseShapeRewriter
    : public OpRewritePattern<tensor::CollapseShapeOp> {
public:
  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.result().getType());
    auto encSrc = getSparseTensorEncoding(op.src().getType());
    // Since a pure dense collapse is very cheap (change of view), for
    // sparse2dense or dense2sparse, we can simply unfuse a sparse
    // conversion from the actual collapse operation itself.
    if (encDst && encSrc) {
      return failure(); // TODO: implement sparse2sparse
    } else if (encSrc) {
      RankedTensorType rtp = op.src().getType().cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.src());
      op->setOperand(0, convert);
      return success();
    } else if (encDst) {
      RankedTensorType rtp = op.result().getType().cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto reshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, denseTp, op.src(), op.getReassociation());
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};

} // namespace
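
// For illustration only: with a dense source and a sparse destination (a
// hypothetical #DCSR alias), ExpandShapeRewriter above turns
//
//   %s = tensor.expand_shape %src [[0, 1]]
//          : tensor<12xf64> into tensor<3x4xf64, #DCSR>
//
// into a dense expansion followed by an explicit conversion:
//
//   %d = tensor.expand_shape %src [[0, 1]]
//          : tensor<12xf64> into tensor<3x4xf64>
//   %s = sparse_tensor.convert %d : tensor<3x4xf64> to tensor<3x4xf64, #DCSR>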

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
  patterns.add<ExpandShapeRewriter, CollapseShapeRewriter>(
      patterns.getContext());
}
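
// For illustration only: a typical caller (e.g. the sparsification pass)
// would wire these patterns into a standard greedy rewrite driver along the
// following lines:
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));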