//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Declarations of data structures.
//===----------------------------------------------------------------------===//

namespace {

// Iteration graph sorting.
enum SortMask {
  kSparseOnly = 0x0,
  kIncludeDense = 0x1,
  kIncludeUndef = 0x2,
  kIncludeAll = 0x3
};

// Reduction kinds.
enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
          OpOperand *op, unsigned nest)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redVal(), sparseOut(op),
        outerParNest(nest), lexIdx(), lexVal(), expValues(), expFilled(),
        expAdded(), expCount(), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can use a scalarized reduction.
  unsigned redExp = -1u;
  Value redVal;
  Reduction redKind = kNoReduc;
  // Sparse tensor as output. Implemented either through direct injective
  // insertion in lexicographic index order (where indices are updated
  // in the temporary array `lexIdx`) or through access pattern expansion
  // in the innermost loop nest (`expValues` through `expCount`).
  OpOperand *sparseOut;
  unsigned outerParNest;
  Value lexIdx;
  Value lexVal;
  Value expValues;
  Value expFilled;
  Value expAdded;
  Value expCount;
  // Current vector length and mask.
  unsigned curVecLength = 1;
  Value curVecMask;
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse compiler analysis methods.
//===----------------------------------------------------------------------===//

/// Helper method to construct a permuted dimension ordering
/// that adheres to the given topological sort.
static AffineMap permute(MLIRContext *context, AffineMap m,
                         std::vector<unsigned> &topSort) {
  unsigned sz = topSort.size();
  assert(m.getNumResults() == sz && "TopoSort/AffineMap size mismatch");
  // Construct the inverse of `m`, to avoid the asymptotic complexity
  // of calling `m.getPermutedPosition` repeatedly.
  SmallVector<unsigned, 4> inv(sz);
  for (unsigned i = 0; i < sz; i++)
    inv[i] = m.getDimPosition(i);
  // Construct the permutation.
  SmallVector<unsigned, 4> perm(sz);
  for (unsigned i = 0; i < sz; i++)
    perm[i] = inv[topSort[i]];
  return AffineMap::getPermutationMap(perm, context);
}

/// Helper method to apply dimension ordering permutation.
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

/// Helper method to translate dim level type to internal representation.
static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects affine expressions
/// that are not a direct index for annotated tensors.
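/// For example, a non-annotated (dense) tensor admits compound subscripts
/// such as a[i+j] or a[2*i], whereas an annotated (sparse) tensor only
/// admits direct indices such as a[i][j]; a subscript like a[i][i] is
/// rejected because the same index is used more than once.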
// TODO: accept more affine cases for sparse tensors
static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
                       bool isDense) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (!merger.isDim(tensor, idx, Dim::kUndef))
      return false; // used more than once
    merger.setDim(tensor, idx, dim);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    if (!isDense)
      return false;
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
  }
  case AffineExprKind::Constant:
    return isDense;
  default:
    return false;
  }
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned tensor = t->getOperandNumber();
      AffineExpr a = map.getResult(perm(enc, d));
      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissible affine expression
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Helper method to add all constraints from the indices in one affine
/// expression before all indices in the other affine expression. For
/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
                               AffineExpr a, AffineExpr b, unsigned fidx) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (b)
      addAffineOrderings(adjM, b, AffineExpr(), idx);
    else
      adjM[fidx][idx] = true;
    break;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
    break;
  }
  default:
    break;
  }
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort, unsigned mask,
                                  OpOperand *skip = nullptr) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    // Skip tensor during cycle resolution.
    if (t == skip)
      continue;
    // Get map and encoding.
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if tensor materializes uninitialized into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<bufferization::AllocTensorOp>();
}

/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the out tensor in the tensor expression codegen is admissible.
/// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
/// nesting depth when a "truly dynamic" sparse tensor output occurs.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort, unsigned exp,
                                  OpOperand **sparseOut,
                                  unsigned &outerParNest) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
  bool allDense = true;
  auto iteratorTypes = op.iterator_types().getValue();
  unsigned numLoops = iteratorTypes.size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen.
  if (merger.isSingleCondition(tensor, exp))
    return true;
  // Accept "truly dynamic" if the output tensor materializes uninitialized
  // into the computation and insertions occur in lexicographic index order.
  if (isMaterializing(lhs->get())) {
    unsigned nest = 0;
    for (unsigned i = 0; i < numLoops; i++) {
      if (isReductionIterator(iteratorTypes[topSort[i]]))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    if (nest >= op.getRank(lhs) - 1) {
      *sparseOut = lhs;
      outerParNest = nest;
      return true;
    }
  }
  return false;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (reductions).
//===----------------------------------------------------------------------===//

/// Maps reduction kind to vector::CombiningKind.
static vector::CombiningKind getCombiningKind(Reduction kind) {
  switch (kind) {
  case kNoReduc:
    break;
  case kSum:
    return vector::CombiningKind::ADD;
  case kProduct:
    return vector::CombiningKind::MUL;
  case kAnd:
    return vector::CombiningKind::AND;
  case kOr:
    return vector::CombiningKind::OR;
  case kXor:
    return vector::CombiningKind::XOR;
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps operation to reduction.
static Reduction getReduction(Kind kind) {
  switch (kind) {
  case Kind::kAddF:
  case Kind::kAddC:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubC:
  case Kind::kSubI:
    return kSum;
  case Kind::kMulF:
  case Kind::kMulC:
  case Kind::kMulI:
    return kProduct;
  case Kind::kAndI:
    return kAnd;
  case Kind::kOrI:
    return kOr;
  case Kind::kXorI:
    return kXor;
  default:
    llvm_unreachable("unexpected reduction operator");
  }
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
static Value genVectorReducInit(CodeGen &codegen, OpBuilder &builder,
                                Location loc, VectorType vtp) {
  Value r = codegen.redVal;
  switch (codegen.redKind) {
  case kNoReduc:
    break;
  case kSum:
  case kXor:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return builder.create<vector::InsertElementOp>(
        loc, r, constantZero(builder, loc, vtp),
        constantIndex(builder, loc, 0));
  case kProduct:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return builder.create<vector::InsertElementOp>(
        loc, r, constantOne(builder, loc, vtp), constantIndex(builder, loc, 0));
  case kAnd:
  case kOr:
    // Initialize reduction vector to: | r | .. | r | r |
    return builder.create<vector::BroadcastOp>(loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
}

/// Generates final value for a vector reduction.
static Value genVectorReducEnd(CodeGen &codegen, OpBuilder &builder,
                               Location loc, VectorType vtp) {
  vector::CombiningKind kind = getCombiningKind(codegen.redKind);
  return builder.create<vector::ReductionOp>(loc, kind, codegen.redVal);
}

/// Updates scalarized reduction value.
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, OpBuilder &builder,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Value tensor = lhs->get();
  bool isInit = op.isInitTensor(lhs);
  // An output tensor can simply materialize from the buffer of the tensor that
  // appears in the outs() clause. For updates, this has the advantage that only
  // the nonzero values are involved in the computation, keeping the operation
  // O(nnz). In all other cases, we are forced to zero out the buffer to enforce
  // the assumption above, which may negatively impact running complexity
  // (viz. O(n^2 + nnz) vs. O(nnz) for matrices).
  // TODO: use better analysis to avoid zeroing out the buffer?
  Value init = builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
  if (!isInit) {
    Value zero = constantZero(builder, loc, denseTp.getElementType());
    builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{init});
  }
  return init;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                       linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp =
            MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
        auto indTp =
            MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
        Value dim = constantIndex(builder, loc, d);
        // Generate sparse primitives to obtain pointer and indices.
        codegen.pointers[tensor][idx] =
            builder.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(builder, loc, t->get(), p);
      if (ShapedType::isDynamic(shape[p]))
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            builder.create<bufferization::ToMemrefOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, builder, op, denseTp, args);
    } else if (t == codegen.sparseOut) {
      // True sparse output needs a lexIdx array.
      Value rank = constantIndex(builder, loc, op.getRank(t));
      auto dynShape = {ShapedType::kDynamicSize};
      auto memTp = MemRefType::get(dynShape, builder.getIndexType());
      codegen.lexIdx = builder.create<memref::AllocaOp>(loc, memTp, rank);
      codegen.lexVal = builder.create<memref::AllocaOp>(
          loc, MemRefType::get({}, elementType));
    } else {
      // Annotated sparse tensors.
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          builder.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  unsigned numScalableDims = codegen.options.enableVLAVectorization;
  return VectorType::get(codegen.curVecLength, etp, numScalableDims);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
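/// For example (a sketch, assuming vector length 16), a loop "for i = 0,
/// 128, 16" yields a constant all-true mask, whereas "for i = 0, 130, 16"
/// yields a mask computed from min(hi - iv, step) in every iteration.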
static Value genVectorMask(CodeGen &codegen, OpBuilder &builder, Value iv,
                           Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, builder.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return builder.create<vector::BroadcastOp>(
          loc, mtp, constantI1(builder, loc, true));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {builder.getAffineSymbolExpr(0),
       builder.getAffineDimExpr(0) - builder.getAffineDimExpr(1)},
      builder.getContext());
  Value end =
      builder.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
  return builder.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, OpBuilder &builder, Value ptr,
                           ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = constantZero(builder, loc, vtp);
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = constantIndex(builder, loc, 0);
    return builder.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs, indexVec,
                                            codegen.curVecMask, pass);
  }
  return builder.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                              codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, OpBuilder &builder, Value rhs,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = constantIndex(builder, loc, 0);
    builder.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                      codegen.curVecMask, rhs);
    return;
  }
  builder.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                        rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen, OpBuilder &builder,
                                     Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return builder.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates an affine expression.
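/// For example, the subscript expression d0 + 2 lowers to an arith.addi of
/// the current loop index assigned to d0 and an index constant 2.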
//
// TODO: generalize for sparse tensor subscripts
//
static Value genAffine(CodeGen &codegen, OpBuilder &builder, AffineExpr a,
                       Location loc) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    return codegen.loops[idx]; // universal dense index
  }
  case AffineExprKind::Add: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return builder.create<arith::AddIOp>(
        loc, genAffine(codegen, builder, binOp.getLHS(), loc),
        genAffine(codegen, builder, binOp.getRHS(), loc));
  }
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return builder.create<arith::MulIOp>(
        loc, genAffine(codegen, builder, binOp.getLHS(), loc),
        genAffine(codegen, builder, binOp.getRHS(), loc));
  }
  case AffineExprKind::Constant: {
    int64_t c = a.cast<AffineConstantExpr>().getValue();
    return constantIndex(builder, loc, c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

/// Generates index for load/store on sparse tensor.
static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) {
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1));
  assert(a.getKind() == AffineExprKind::DimId);
  unsigned idx = a.cast<AffineDimExpr>().getPosition();
  return codegen.loops[idx];
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodeGen &codegen, OpBuilder &builder,
                          linalg::GenericOp op, OpOperand *t,
                          SmallVector<Value, 4> &args) {
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    // Note that currently, all sparse subscripts are simple.
    // TODO: accept affine too?
    AffineExpr a = map.getResult(perm(enc, rank - 1));
    assert(a.getKind() == AffineExprKind::DimId);
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      args.push_back(genAffine(codegen, builder, a, op.getLoc()));
    }
  }
  return codegen.buffers[tensor];
}

/// Generates insertion code to implement dynamic tensor load.
static Value genInsertionLoad(CodeGen &codegen, OpBuilder &builder,
                              linalg::GenericOp op, OpOperand *t) {
  Location loc = op.getLoc();
  // Direct lexicographic index order, tensor loads as zero.
  if (!codegen.expValues) {
    Type tp = getElementTypeOrSelf(t->get().getType());
    return constantZero(builder, loc, tp);
  }
  // Load from expanded access pattern.
  Value index = genIndex(codegen, op, t);
  return builder.create<memref::LoadOp>(loc, codegen.expValues, index);
}

/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
                              linalg::GenericOp op, OpOperand *t, Value rhs) {
  Location loc = op.getLoc();
  // Direct insertion in lexicographic index order.
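  // (The current index tuple is kept up to date in the lexIdx buffer by
  // genLocals; the scalar value is passed through the lexVal buffer.)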
  if (!codegen.expValues) {
    builder.create<memref::StoreOp>(loc, rhs, codegen.lexVal);
    builder.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, codegen.lexVal);
    return;
  }
  // Generates insertion code along expanded access pattern.
  //   if (!expFilled[i]) then
  //     expFilled[i] = true
  //     expAdded[inserts++] = i
  //   endif
  //   values[i] = rhs
  Value index = genIndex(codegen, op, t);
  Value fval = constantI1(builder, loc, false);
  Value tval = constantI1(builder, loc, true);
  // If statement.
  Value filled = builder.create<memref::LoadOp>(loc, codegen.expFilled, index);
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                             filled, fval);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
                                             /*else=*/true);
  // True branch.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  builder.create<memref::StoreOp>(loc, tval, codegen.expFilled, index);
  builder.create<memref::StoreOp>(loc, index, codegen.expAdded,
                                  codegen.expCount);
  Value one = constantIndex(builder, loc, 1);
  Value add = builder.create<arith::AddIOp>(loc, codegen.expCount, one);
  builder.create<scf::YieldOp>(loc, add);
  // False branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, codegen.expCount);
  builder.setInsertionPointAfter(ifOp);
  // Value assignment.
  codegen.expCount = ifOp.getResult(0);
  builder.create<memref::StoreOp>(loc, rhs, codegen.expValues, index);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, builder, val);
    return val;
  }
  // Load during insertion.
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  if (t == codegen.sparseOut)
    return genInsertionLoad(codegen, builder, op, t);
  // Actual load.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, builder, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, builder, ptr, args);
  return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned exp, Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = builder.create<arith::SelectOp>(loc, codegen.curVecMask, rhs,
                                            codegen.redVal);
    updateReduc(merger, codegen, rhs);
    return;
  }
  // Store during insertion.
  OpOperand *t = op.getOutputOperand(0);
  if (t == codegen.sparseOut) {
    if (!rhs) {
      // Only unary and binary are allowed to return uninitialized rhs
      // to indicate missing output.
      assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary);
    } else {
      genInsertionStore(codegen, builder, op, t, rhs);
    }
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, builder, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, builder, rhs, ptr, args);
  else
    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, OpBuilder &builder, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, builder, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = builder.create<arith::ExtUIOp>(
            loc, vectorType(codegen, builder.getI32Type()), vload);
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = builder.create<arith::ExtUIOp>(
            loc, vectorType(codegen, builder.getI64Type()), vload);
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = builder.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
    load =
        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               OpBuilder &builder, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, builder, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, OpBuilder &builder, Location loc,
                        Value size, Value p, Value i) {
  Value mul = builder.create<arith::MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv =
        builder.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul);
    mul = genVectorInvariantValue(codegen, builder, inv);
  }
  return builder.create<arith::AddIOp>(loc, mul, i);
}

/// Generates an index value.
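/// For example, under vectorization with vector length 4, the innermost
/// scalar loop index i is converted into the vector [i, i+1, i+2, i+3].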
static Value genIndexValue(CodeGen &codegen, OpBuilder &builder, unsigned idx,
                           unsigned ldx) {
  Value ival = codegen.loops[idx];
  Type itype = ival.getType();
  // During vectorization, we either encounter:
  // (1) indices already in vector form, as in ... = ind[lo:hi], good to go, or
  // (2) single index, as in ... = i, must convert to [i, i+1, ...] for inner i.
  unsigned vl = codegen.curVecLength;
  if (vl > 1 && !itype.isa<VectorType>()) {
    Location loc = ival.getLoc();
    VectorType vtp = vectorType(codegen, itype);
    ival = builder.create<vector::BroadcastOp>(loc, vtp, ival);
    if (idx == ldx) {
      Value incr;
      if (vtp.isScalable()) {
        Type stepvty = vectorType(codegen, builder.getI64Type());
        Value stepv = builder.create<LLVM::StepVectorOp>(loc, stepvty);
        incr = builder.create<arith::IndexCastOp>(loc, vtp, stepv);
      } else {
        SmallVector<APInt, 4> integers;
        for (unsigned i = 0; i < vl; i++)
          integers.push_back(APInt(/*width=*/64, i));
        auto values = DenseElementsAttr::get(vtp, integers);
        incr = builder.create<arith::ConstantOp>(loc, vtp, values);
      }
      ival = builder.create<arith::AddIOp>(loc, ival, incr);
    }
  }
  return ival;
}

/// Semi-ring branches are simply inlined by the sparse compiler. Prior
/// analysis has verified that all computations are "local" to the inlined
/// branch or otherwise invariantly defined outside the loop nest, with the
/// exception of index computations, which need to be relinked to actual
/// inlined cloned code.
static Value relinkBranch(CodeGen &codegen, RewriterBase &rewriter,
                          Block *block, Value e, unsigned ldx) {
  if (Operation *def = e.getDefiningOp()) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return genIndexValue(codegen, rewriter, indexOp.dim(), ldx);
    if (def->getBlock() == block) {
      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
        def->setOperand(
            i, relinkBranch(codegen, rewriter, block, def->getOperand(i), ldx));
    }
  }
  return e;
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
                    linalg::GenericOp op, unsigned exp, unsigned ldx) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  if (merger.exp(exp).kind == Kind::kIndex)
    return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx);
  Value v0 =
      genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
  Value v1 =
      genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx);
  Value ee = merger.buildExp(rewriter, loc, exp, v0, v1);
  if (ee && (merger.exp(exp).kind == Kind::kUnary ||
             merger.exp(exp).kind == Kind::kBinary ||
             merger.exp(exp).kind == Kind::kBinaryBranch))
    ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
  return ee;
}

/// Determines if affine expression is invariant.
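/// For example, the subscript expression d0 + d1 is invariant once loop
/// indices have been assigned to both d0 and d1; atLevel signals whether
/// one of them is exactly the current loop level ldx.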
static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
                              unsigned ldx, bool &atLevel) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (idx == ldx)
      atLevel = true;
    return codegen.loops[idx] != nullptr; // no longer in play?
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
  }
  default:
    return true;
  }
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                          linalg::GenericOp op, unsigned exp, unsigned ldx,
                          bool atStart, Kind last = Kind::kTensor) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (!isInvariantAffine(codegen, a, ldx, atLevel))
        return; // still in play
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    if (!atLevel)
      return;
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      // Start or end a scalarized reduction.
      if (atStart) {
        Value load = genTensorLoad(merger, codegen, builder, op, exp);
        codegen.redKind = getReduction(last);
        codegen.redExp = exp;
        updateReduc(merger, codegen, load);
      } else {
        Value redVal = codegen.redVal;
        updateReduc(merger, codegen, Value());
        codegen.redExp = -1u;
        codegen.redKind = kNoReduc;
        genTensorStore(merger, codegen, builder, op, exp, redVal);
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      merger.exp(exp).val =
          atStart ? genTensorLoad(merger, codegen, builder, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant &&
             merger.exp(exp).kind != Kind::kIndex) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    Kind last = merger.exp(exp).kind;
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, builder, op, e0, ldx, atStart, last);
    genInvariants(merger, codegen, builder, op, e1, ldx, atStart, last);
  }
}

/// Generates an expanded access pattern in innermost dimension.
static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                         linalg::GenericOp op, unsigned at, bool atStart) {
  OpOperand *lhs = codegen.sparseOut;
  if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 ||
      at != codegen.outerParNest)
    return; // not needed at this level
  // Generate start or end of an expanded access pattern.
  Value tensor = lhs->get();
  Location loc = op.getLoc();
  if (atStart) {
    auto dynShape = {ShapedType::kDynamicSize};
    Type etp = tensor.getType().cast<ShapedType>().getElementType();
    Type t1 = MemRefType::get(dynShape, etp);
    Type t2 = MemRefType::get(dynShape, builder.getI1Type());
    Type t3 = MemRefType::get(dynShape, builder.getIndexType());
    Type t4 = builder.getIndexType();
    auto res =
        builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
    assert(res.getNumResults() == 4);
    assert(!codegen.expValues);
    codegen.expValues = res.getResult(0);
    codegen.expFilled = res.getResult(1);
    codegen.expAdded = res.getResult(2);
    codegen.expCount = res.getResult(3);
  } else {
    assert(codegen.expValues);
    builder.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues,
                               codegen.expFilled, codegen.expAdded,
                               codegen.expCount);
    codegen.expValues = codegen.expFilled = codegen.expAdded =
        codegen.expCount = Value();
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = constantIndex(builder, loc, 1);
        Value p0 = (pat == 0) ? constantIndex(builder, loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p0);
        Value p1 = builder.create<arith::AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = constantIndex(builder, loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isReduction,
                        bool isSparse) {
  // Reject vectorization of sparse output, unless innermost is reduction.
  if (codegen.sparseOut && !isReduction)
    return false;
  // Inspect strategy.
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  // Reject parallelization of sparse output.
  if (codegen.sparseOut)
    return false;
  // Inspect strategy.
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit stride for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// This prevents effective vectorization.
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        AffineExpr a = map.getResult(d);
        // Report non-unit stride if innermost index appears at an outer
        // dimension (true non-unit stride) or if the innermost index appears
        // in a compound subscript in the innermost dimension. Even if the
        // latter is unit stride, it does not play well with scatter/gather.
        // TODO: accept unit stride affine innermost like a[i,j+k+1]?
        if (a.isFunctionOfDim(idx) &&
            ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId)))
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                         linalg::GenericOp op, bool isOuter, bool isInner,
                         unsigned idx, BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = isReductionIterator(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = constantIndex(builder, loc, codegen.curVecLength);
  if (isVector && codegen.options.enableVLAVectorization) {
    Value vscale = builder.create<vector::VectorScaleOp>(
        loc, IndexType::get(builder.getContext()));
    step = builder.create<arith::MulIOp>(loc, vscale, step);
  }

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = builder.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    builder.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential or vector loop.
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    // In a vector loop, bring reduction into SIMD form, if not already.
    if (isVector && !codegen.redVal.getType().isa<VectorType>()) {
      VectorType vtp = vectorType(codegen, codegen.redVal.getType());
      Value vred = genVectorReducInit(codegen, builder, loc, vtp);
      updateReduc(merger, codegen, vred);
    }
    operands.push_back(codegen.redVal);
  }
  if (codegen.expValues)
    operands.push_back(codegen.expCount);
  scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (codegen.redVal)
    updateReduc(merger, codegen, forOp.getRegionIterArgs().front());
  if (codegen.expValues)
    codegen.expCount = forOp.getRegionIterArgs().back();
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  builder.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, builder, iv, lo, hi, step);
  return forOp;
}

/// Emit a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned idx, bool needsUniv,
                           BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = builder.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (codegen.redVal) {
    types.push_back(codegen.redVal.getType());
    operands.push_back(codegen.redVal);
  }
  if (codegen.expValues) {
    types.push_back(indexType);
    operands.push_back(codegen.expCount);
  }
  if (needsUniv) {
    types.push_back(indexType);
    operands.push_back(codegen.loops[idx]);
  }
  assert(types.size() == operands.size());
  Location loc = op.getLoc();
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);

  SmallVector<Location> locs(types.size(), loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  builder.setInsertionPointToStart(&whileOp.getBefore().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                op1, op2);
      cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (codegen.redVal)
    updateReduc(merger, codegen, after->getArgument(o++));
  if (codegen.expValues)
    codegen.expCount = after->getArgument(o++);
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
  builder.setInsertionPointToStart(&whileOp.getAfter().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                          linalg::GenericOp op, std::vector<unsigned> &topSort,
                          unsigned at, bool needsUniv, BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, builder, op, isOuter, isInner, idx, indices);
  }
  return genWhile(merger, codegen, builder, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                      linalg::GenericOp op, std::vector<unsigned> &topSort,
                      unsigned at, bool needsUniv, BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
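  // Each sparse tensor in play loads its index value at the current position;
  // without a universal index, the minimum over these loads drives the loop.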
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, builder, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp = builder.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, load, min);
          min = builder.create<arith::SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? constantIndex(builder, loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, builder, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }

  // Move the insertion indices in lexicographic index order. During access
  // pattern expansion, we can skip setting the innermost dimension.
  if (codegen.sparseOut && !codegen.expValues) {
    Value pos = constantIndex(builder, loc, at);
    builder.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
                                    pos);
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              OpBuilder &builder, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              BitVector &induction, scf::WhileOp whileOp) {
  Location loc = op.getLoc();
  // Finalize each else branch of all if statements.
  if (codegen.redVal || codegen.expValues) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               builder.getInsertionBlock()->getParentOp())) {
      unsigned y = 0;
      SmallVector<Value, 4> yields;
      if (codegen.redVal) {
        yields.push_back(codegen.redVal);
        updateReduc(merger, codegen, ifOp.getResult(y++));
      }
      if (codegen.expValues) {
        yields.push_back(codegen.expCount);
        codegen.expCount = ifOp->getResult(y++);
      }
      assert(y == yields.size());
      builder.create<scf::YieldOp>(loc, yields);
      builder.setInsertionPointAfter(ifOp);
    }
  }
  builder.setInsertionPointToEnd(&whileOp.getAfter().front());
  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = constantIndex(builder, loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                op1, op2);
      Value add = builder.create<arith::AddIOp>(loc, op3, one);
      operands.push_back(builder.create<arith::SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
    }
  }
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, whileOp->getResult(o++));
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = whileOp->getResult(o++);
  }
  if (needsUniv) {
    operands.push_back(
        builder.create<arith::AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = whileOp->getResult(o++);
  }
  assert(o == operands.size());
  builder.create<scf::YieldOp>(loc, operands);
  builder.setInsertionPointAfter(whileOp);
}

/// Generates the induction structure for a for-loop.
static void genForInduction(Merger &merger, CodeGen &codegen,
                            OpBuilder &builder, linalg::GenericOp op,
                            Operation *loop) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, loop->getResult(o++));
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = loop->getResult(o++);
  }
  assert(o == operands.size());
  if (o > 0)
    builder.create<scf::YieldOp>(loc, operands);
  builder.setInsertionPointAfter(loop);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                       linalg::GenericOp op, unsigned idx,
                       BitVector &conditions) {
  Location loc = op.getLoc();
  SmallVector<Type, 4> types;
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                               op1, op2);
      } else {
        clause = constantI1(builder, loc, true);
      }
      cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
    }
  }
  if (codegen.redVal)
    types.push_back(codegen.redVal.getType());
  if (codegen.expValues)
    types.push_back(builder.getIndexType());
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}

/// Generates the end of the true branch of an if-statement within a
/// while-loop.
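/// The current reduction value and expansion count are yielded from the true
/// branch, and the builder is left positioned at the start of the else branch
/// so that the next lattice point can be emitted there.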
static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                  linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
                  Value redInput, Value cntInput) {
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, redInput);
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = cntInput;
  }
  if (!operands.empty())
    builder.create<scf::YieldOp>(op.getLoc(), operands);
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at the given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                         linalg::GenericOp op, std::vector<unsigned> &topSort,
                         unsigned exp, unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  assert(!codegen.loops[idx]);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/true);
  // Emit access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, builder, op, at, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  bool needsUniv =
      genInit(merger, codegen, builder, op, topSort, at, merger.lat(l0).bits);
  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  if (needsUniv) {
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}

/// Starts a single loop in the current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            OpBuilder &builder, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, builder, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, builder, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in the current sequence. Returns the new value for
/// needsUniv.
static bool endLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    genWhileInduction(merger, codegen, builder, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp);
    return needsUniv;
  }
  // End a for-loop.
  genForInduction(merger, codegen, builder, op, loop);
  return false;
}

/// Ends a loop sequence at the given level.
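/// This clears the loop index, demotes a pending vectorized reduction back to
/// scalar form, and unwinds the bookkeeping for invariants and access pattern
/// expansion that was set up when the sequence started.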
static void endLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                       linalg::GenericOp op, unsigned exp, unsigned at,
                       unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  codegen.loops[idx] = Value();
  // Bring a pending reduction back from SIMD form when sequence ends.
  if (codegen.redVal)
    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
      updateReduc(merger, codegen,
                  genVectorReducEnd(codegen, builder, op.getLoc(), vtp));
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/false);
  // Finalize access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, builder, op, at, /*atStart=*/false);
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign the remaining tensor (sub)expression to the output
  // tensor.
  if (at == topSort.size()) {
    unsigned ldx = topSort[at - 1];
    Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx);
    genTensorStore(merger, codegen, rewriter, op, exp, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    Value redInput = codegen.redVal;
    Value cntInput = codegen.expCount;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
}

/// Converts the result computed by the sparse kernel into the required form.
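/// For an annotated (sparse) output, the tensor is rematerialized from its
/// underlying sparse storage with a sparse_tensor LoadOp (which also signals
/// whether pending in-place insertions must be finalized); a dense output is
/// simply rematerialized from the bufferized value array.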
static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
                      linalg::GenericOp op) {
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format.
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
                                        codegen.sparseOut == lhs);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = codegen.buffers.back(); // value array
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
  }
}

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure tensors
    // are visited in natural index order. Gradually relaxes the considered
    // constraints until an acyclic iteration graph results, such that sparse
    // code generation can proceed. As a last resort, an attempt is made
    // to resolve cycles by inserting a conversion.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, SortMask::kIncludeAll) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly)) {
      return resolveCycle(merger, rewriter, op);
    }

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
    if (!optExp.has_value())
      return failure();
    unsigned exp = optExp.value();

    // Rejects an inadmissible tensor expression.
    OpOperand *sparseOut = nullptr;
    unsigned outerParNest = 0;
    if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
                               outerParNest))
      return failure();

    // Recursively generates code.
    merger.setHasSparseOut(sparseOut != nullptr);
    CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  // Last resort cycle resolution.
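  // If none of the sort masks above yields an acyclic iteration graph, try to
  // break the cycle by converting a single sparse input into a sparse tensor
  // whose dimension ordering matches the topological order (see below).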
  LogicalResult resolveCycle(Merger &merger, PatternRewriter &rewriter,
                             linalg::GenericOp op) const {
    // Compute topological sort while leaving out every
    // sparse input tensor in succession until an acyclic
    // iteration graph results.
    std::vector<unsigned> topSort;
    for (OpOperand *t : op.getInputOperands()) {
      unsigned tensor = t->getOperandNumber();
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      if (!srcEnc ||
          !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly, t))
        continue;
      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order. Also releases the temporary sparse tensor.
      //
      // TODO: investigate fusing the conversion with computation,
      // especially if it is a direct yield!
      //
      auto srcTp = tval.getType().cast<RankedTensorType>();
      auto dstEnc = SparseTensorEncodingAttr::get(
          op->getContext(), srcEnc.getDimLevelType(),
          permute(getContext(), op.getTiedIndexingMap(t), topSort), // new order
          srcEnc.getPointerBitWidth(), srcEnc.getIndexBitWidth());
      auto dstTp = RankedTensorType::get(srcTp.getShape(),
                                         srcTp.getElementType(), dstEnc);
      auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
      op->setOperand(tensor, convert);
      rewriter.setInsertionPointAfter(op);
      rewriter.create<ReleaseOp>(tval.getLoc(), convert);
      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }

  /// Options to control sparse code generation.
  SparsificationOptions options;
};

/// Sparse rewriting rule for reshape operator.
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.getResult().getType());
    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
    // Since a pure dense expansion is very cheap (a mere change of view), for
    // a sparse2dense or dense2sparse reshape we simply unfuse a sparse
    // conversion from the reshape operation itself.
    // All other cases are handled elsewhere.
    if (encDst && encSrc) {
      return failure();
    } else if (encSrc) {
      RankedTensorType rtp =
          op.getSrc().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
      op->setOperand(0, convert);
      return success();
    } else if (encDst) {
      RankedTensorType rtp =
          op.getResult().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
                                                op.getReassociation());
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
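/// This includes the GenericOpSparsifier rule for annotated linalg.generic
/// operations as well as the ReshapeRewriter rules that unfuse sparse/dense
/// conversions from tensor reshape operations.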
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}
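
// For illustration only: a sketch of how a pass might apply these patterns
// with the greedy pattern rewrite driver (the surrounding pass and the
// operation `op` are assumptions, not part of this file):
//
//   RewritePatternSet patterns(op->getContext());
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));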