1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements converting sparse tensor types to actual sparse code. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CodegenUtils.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 18 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 19 #include "mlir/Dialect/Func/IR/FuncOps.h" 20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 21 #include "mlir/Dialect/Linalg/IR/Linalg.h" 22 #include "mlir/Dialect/Linalg/Utils/Utils.h" 23 #include "mlir/Dialect/MemRef/IR/MemRef.h" 24 #include "mlir/Dialect/SCF/IR/SCF.h" 25 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 26 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 28 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 29 #include "mlir/Dialect/Vector/IR/VectorOps.h" 30 #include "mlir/IR/Matchers.h" 31 #include "mlir/IR/TensorEncoding.h" 32 #include "llvm/ADT/SmallBitVector.h" 33 34 using namespace mlir; 35 using namespace mlir::sparse_tensor; 36 37 //===----------------------------------------------------------------------===// 38 // Declarations of data structures. 39 //===----------------------------------------------------------------------===// 40 41 namespace { 42 43 // Iteration graph sorting. 44 enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 }; 45 46 // Reduction kinds. 47 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor }; 48 49 // Code generation. 50 struct CodeGen { 51 CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops, 52 OpOperand *op, unsigned nest) 53 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 54 pointers(numTensors, std::vector<Value>(numLoops)), 55 indices(numTensors, std::vector<Value>(numLoops)), 56 highs(numTensors, std::vector<Value>(numLoops)), 57 pidxs(numTensors, std::vector<Value>(numLoops)), 58 idxs(numTensors, std::vector<Value>(numLoops)), redVal(), sparseOut(op), 59 outerParNest(nest), lexIdx(), lexVal(), expValues(), expFilled(), 60 expAdded(), expCount(), curVecMask() {} 61 /// Sparsification options. 62 SparsificationOptions options; 63 /// Universal dense indices and upper bounds (by index). The loops array 64 /// is updated with the value of the universal dense index in the current 65 /// loop. The sizes array is set once with the inferred dimension sizes. 66 std::vector<Value> loops; 67 std::vector<Value> sizes; 68 /// Buffers for storing dense and sparse numerical values (by tensor). 69 /// This array is set once during bufferization of all tensors. 70 std::vector<Value> buffers; 71 /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 72 /// This array is set once during bufferization of all sparse tensors. 73 std::vector<std::vector<Value>> pointers; 74 std::vector<std::vector<Value>> indices; 75 /// Sparse iteration information (by tensor and index). These arrays 76 /// are updated to remain current within the current loop. 
77 std::vector<std::vector<Value>> highs; 78 std::vector<std::vector<Value>> pidxs; 79 std::vector<std::vector<Value>> idxs; 80 /// Current reduction, updated during code generation. When indices of a 81 /// reduction are exhausted, all inner loops can use a scalarized reduction. 82 unsigned redExp = -1u; 83 Value redVal; 84 Reduction redKind = kNoReduc; 85 // Sparse tensor as output. Implemented either through direct injective 86 // insertion in lexicographic index order (where indices are updated 87 // in the temporary array `lexIdx`) or through access pattern expansion 88 // in the innermost loop nest (`expValues` through `expCount`). 89 OpOperand *sparseOut; 90 unsigned outerParNest; 91 Value lexIdx; 92 Value lexVal; 93 Value expValues; 94 Value expFilled; 95 Value expAdded; 96 Value expCount; 97 // Current vector length and mask. 98 unsigned curVecLength = 1; 99 Value curVecMask; 100 }; 101 102 } // namespace 103 104 //===----------------------------------------------------------------------===// 105 // Sparse compiler analysis methods. 106 //===----------------------------------------------------------------------===// 107 108 /// Helper method to apply dimension ordering permutation. 109 static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) { 110 if (enc) { 111 auto order = enc.getDimOrdering(); 112 if (order) { 113 assert(order.isPermutation()); 114 return order.getDimPosition(d); 115 } 116 } 117 return d; 118 } 119 120 /// Helper method to translate dim level type to internal representation. 121 static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) { 122 if (enc) { 123 SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 124 if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 125 return Dim::kSparse; 126 if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 127 return Dim::kSingle; 128 } 129 return Dim::kDense; 130 } 131 132 /// Helper method to inspect affine expressions. Rejects cases where the 133 /// same index is used more than once. Also rejects affine expressions 134 /// that are not a direct index for annotated tensors. 135 // TODO: accept more affine cases for sparse tensors 136 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim, 137 bool isDense) { 138 switch (a.getKind()) { 139 case AffineExprKind::DimId: { 140 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 141 if (!merger.isDim(tensor, idx, Dim::kUndef)) 142 return false; // used more than once 143 merger.setDim(tensor, idx, dim); 144 return true; 145 } 146 case AffineExprKind::Add: 147 case AffineExprKind::Mul: { 148 if (!isDense) 149 return false; 150 auto binOp = a.cast<AffineBinaryOpExpr>(); 151 return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) && 152 findAffine(merger, tensor, binOp.getRHS(), dim, isDense); 153 } 154 case AffineExprKind::Constant: 155 return isDense; 156 default: 157 return false; 158 } 159 } 160 161 /// Helper method to inspect sparse encodings in the tensor types. 162 /// Fills the per-dimension sparsity information for all tensors. 163 /// Returns true if the sparse annotations and affine subscript 164 /// expressions of all tensors are admissable. Returns false if 165 /// no annotations are found or inadmissable constructs occur. 
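/// For example (illustrative), a compound subscript such as A(i + j) is only
/// accepted on a dense (non-annotated) tensor, while an annotated sparse
/// tensor must use a direct index such as A(i, j); reusing the same index,
/// as in A(i, i), is rejected for either kind of tensor.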
166 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 167 bool annotated = false; 168 for (OpOperand *t : op.getInputAndOutputOperands()) { 169 auto map = op.getTiedIndexingMap(t); 170 auto enc = getSparseTensorEncoding(t->get().getType()); 171 if (enc) 172 annotated = true; 173 assert(map.getNumResults() == op.getRank(t)); 174 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 175 unsigned tensor = t->getOperandNumber(); 176 AffineExpr a = map.getResult(perm(enc, d)); 177 if (!findAffine(merger, tensor, a, toDim(enc, d), !enc)) 178 return false; // inadmissable affine expression 179 } 180 } 181 return annotated; 182 } 183 184 /// A DFS helper to compute a topological sort. Note that recursion is 185 /// bounded by the number of implicit loops, which is always small. 186 /// Returns false when a cycle is detected. 187 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 188 std::vector<unsigned> &topSort, 189 std::vector<std::vector<bool>> &adjM) { 190 if (visit[i] != 0) 191 return visit[i] != 1; // 1 denotes cycle! 192 visit[i] = 1; 193 for (unsigned j = 0, e = visit.size(); j < e; j++) 194 if (adjM[i][j]) 195 if (!topSortDFS(j, visit, topSort, adjM)) 196 return false; 197 visit[i] = 2; 198 topSort.push_back(i); 199 return true; 200 } 201 202 /// Helper method to add all constraints from the indices in one affine 203 /// expression before all indices in the other affine expression. For 204 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3. 205 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM, 206 AffineExpr a, AffineExpr b, unsigned fidx) { 207 switch (a.getKind()) { 208 case AffineExprKind::DimId: { 209 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 210 if (b) 211 addAffineOrderings(adjM, b, AffineExpr(), idx); 212 else 213 adjM[fidx][idx] = true; 214 break; 215 } 216 case AffineExprKind::Add: 217 case AffineExprKind::Mul: { 218 auto binOp = a.cast<AffineBinaryOpExpr>(); 219 addAffineOrderings(adjM, binOp.getLHS(), b, fidx); 220 addAffineOrderings(adjM, binOp.getRHS(), b, fidx); 221 break; 222 } 223 default: 224 break; 225 } 226 } 227 228 /// Computes a topologically sorted iteration graph for the linalg operation. 229 /// Ensures all tensors are visited in natural index order. This is essential 230 /// for sparse storage formats since these only support access along fixed 231 /// dimensions. Even for dense storage formats, however, the natural index 232 /// order yields innermost unit-stride access with better spatial locality. 233 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, 234 std::vector<unsigned> &topSort, 235 unsigned mask) { 236 // Set up an n x n from/to adjacency matrix of the iteration graph 237 // for the implicit loop indices i_0 .. i_n-1. 238 unsigned n = op.getNumLoops(); 239 std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false)); 240 241 // Iterate over the indexing maps of every tensor in the tensor expression. 242 for (OpOperand *t : op.getInputAndOutputOperands()) { 243 auto map = op.getTiedIndexingMap(t); 244 auto enc = getSparseTensorEncoding(t->get().getType()); 245 assert(map.getNumDims() == n); 246 // Skip dense tensor constraints when not requested. 247 if (!(mask & SortMask::kIncludeDense) && !enc) 248 continue; 249 // Each tensor expression and optional dimension ordering (row-major 250 // by default) puts an ordering constraint on the loop indices. 
    // For example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if tensor has an in-place annotation.
static bool isInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<func::FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(),
              bufferization::BufferizableOpInterface::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Returns true if tensor materializes uninitialized into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<bufferization::AllocTensorOp>();
}

/// Returns true when the tensor expression is admissable for codegen.
/// Since all sparse input tensors are admissable, we just need to check
/// whether the out tensor in the tensor expression codegen is admissable.
/// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
/// nesting depth when a "truly dynamic" sparse tensor output occurs.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort, unsigned exp,
                                  OpOperand **sparseOut,
                                  unsigned &outerParNest) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissable since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissable since insertions cannot occur.
  bool allDense = true;
  auto iteratorTypes = op.iterator_types().getValue();
  unsigned numLoops = iteratorTypes.size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissable without special codegen, provided
  // the tensor's underlying sparse storage scheme can be modified in place.
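  // For example (illustrative), a kernel like x(i) = x(i) * 2.0 only rescales
  // existing nonzero values and never inserts new ones, so the buffers of x
  // can simply be updated in place.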
332 if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get())) 333 return true; 334 // Accept "truly dynamic" if the output tensor materializes uninitialized 335 // into the computation and insertions occur in lexicographic index order. 336 if (isMaterializing(lhs->get())) { 337 unsigned nest = 0; 338 for (unsigned i = 0; i < numLoops; i++) { 339 if (isReductionIterator(iteratorTypes[topSort[i]])) 340 break; // terminate at first reduction 341 nest++; 342 } 343 // Determine admissable dynamic insertion situations: 344 // (1) fully injective, since there are no reductions, 345 // (2) admissable 1-d expansion in innermost dimension. 346 if (nest >= op.getRank(lhs) - 1) { 347 *sparseOut = lhs; 348 outerParNest = nest; 349 return true; 350 } 351 } 352 return false; 353 } 354 355 //===----------------------------------------------------------------------===// 356 // Sparse compiler synthesis methods (reductions). 357 //===----------------------------------------------------------------------===// 358 359 /// Maps reduction kind to vector::CombiningKind. 360 static vector::CombiningKind getCombiningKind(Reduction kind) { 361 switch (kind) { 362 case kNoReduc: 363 break; 364 case kSum: 365 return vector::CombiningKind::ADD; 366 case kProduct: 367 return vector::CombiningKind::MUL; 368 case kAnd: 369 return vector::CombiningKind::AND; 370 case kOr: 371 return vector::CombiningKind::OR; 372 case kXor: 373 return vector::CombiningKind::XOR; 374 } 375 llvm_unreachable("unknown reduction kind"); 376 } 377 378 /// Maps operation to reduction. 379 static Reduction getReduction(Kind kind) { 380 switch (kind) { 381 case Kind::kAddF: 382 case Kind::kAddC: 383 case Kind::kAddI: 384 case Kind::kSubF: 385 case Kind::kSubC: 386 case Kind::kSubI: 387 return kSum; 388 case Kind::kMulF: 389 case Kind::kMulC: 390 case Kind::kMulI: 391 return kProduct; 392 case Kind::kAndI: 393 return kAnd; 394 case Kind::kOrI: 395 return kOr; 396 case Kind::kXorI: 397 return kXor; 398 default: 399 llvm_unreachable("unexpected reduction operator"); 400 } 401 } 402 403 /// Generates an initial value for a vector reduction, following the scheme 404 /// given in Chapter 5 of "The Software Vectorization Handbook", where the 405 /// initial scalar value is correctly embedded in the vector reduction value, 406 /// and a straightforward horizontal reduction will complete the operation. 407 static Value genVectorReducInit(CodeGen &codegen, OpBuilder &builder, 408 Location loc, VectorType vtp) { 409 Value r = codegen.redVal; 410 switch (codegen.redKind) { 411 case kNoReduc: 412 break; 413 case kSum: 414 case kXor: 415 // Initialize reduction vector to: | 0 | .. | 0 | r | 416 return builder.create<vector::InsertElementOp>( 417 loc, r, constantZero(builder, loc, vtp), 418 constantIndex(builder, loc, 0)); 419 case kProduct: 420 // Initialize reduction vector to: | 1 | .. | 1 | r | 421 return builder.create<vector::InsertElementOp>( 422 loc, r, constantOne(builder, loc, vtp), constantIndex(builder, loc, 0)); 423 case kAnd: 424 case kOr: 425 // Initialize reduction vector to: | r | .. | r | r | 426 return builder.create<vector::BroadcastOp>(loc, vtp, r); 427 } 428 llvm_unreachable("unknown reduction kind"); 429 } 430 431 /// Generates final value for a vector reduction. 
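/// For example (illustrative), a vector sum reduction with partial results
/// | r0 | r1 | r2 | r3 | is collapsed horizontally into the scalar
/// r0 + r1 + r2 + r3 by the emitted vector.reduction operation.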
static Value genVectorReducEnd(CodeGen &codegen, OpBuilder &builder,
                               Location loc, VectorType vtp) {
  vector::CombiningKind kind = getCombiningKind(codegen.redKind);
  return builder.create<vector::ReductionOp>(loc, kind, codegen.redVal);
}

/// Updates scalarized reduction value.
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, OpBuilder &builder,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Value tensor = lhs->get();
  bool isInit = op.isInitTensor(lhs);
  // An output tensor that is in-place can simply materialize from the buffer
  // of the tensor that appears in the outs() clause. For updates, this has
  // the advantage that only the nonzero values are involved in the
  // computation, keeping the operation O(nnz). In all other cases, we are
  // forced to zero out the buffer to enforce the assumption above, which may
  // negatively impact running complexity (viz. O(n^2 + nnz) vs. O(nnz) for
  // matrices).
  // TODO: use better analysis to avoid zeroing out the buffer?
  if (isInPlace(tensor)) {
    Value init =
        builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
    if (!isInit) {
      Value zero = constantZero(builder, loc, denseTp.getElementType());
      builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{init});
    }
    return init;
  }
  // By default, a new buffer is allocated which is either set to zero (when
  // no updates occur or the tensor materializes into this computation) or
  // initialized to the value of the tensor defined in the outs() clause.
  // This is always correct (since it enforces all assumptions above) but
  // may negatively impact running complexity as explained above.
  Value alloc = builder.create<memref::AllocOp>(loc, denseTp, args);
  if (!isInit || isMaterializing(tensor)) {
    Value zero = constantZero(builder, loc, denseTp.getElementType());
    builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{alloc});
  } else {
    Value init =
        builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
    builder.create<memref::CopyOp>(loc, init, alloc);
  }
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                       linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp =
            MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
        auto indTp =
            MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
        Value dim = constantIndex(builder, loc, d);
        // Generate sparse primitives to obtain pointer and indices.
        codegen.pointers[tensor][idx] =
            builder.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(builder, loc, t->get(), p);
      if (ShapedType::isDynamic(shape[p]))
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            builder.create<bufferization::ToMemrefOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, builder, op, denseTp, args);
    } else if (t == codegen.sparseOut) {
      // True sparse output needs a lexIdx array.
      Value rank = constantIndex(builder, loc, op.getRank(t));
      auto dynShape = {ShapedType::kDynamicSize};
      auto memTp = MemRefType::get(dynShape, builder.getIndexType());
      codegen.lexIdx = builder.create<memref::AllocaOp>(loc, memTp, rank);
      codegen.lexVal = builder.create<memref::AllocaOp>(
          loc, MemRefType::get({}, elementType));
    } else {
      // Annotated sparse tensors.
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          builder.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
}

/// Constructs vector type.
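/// For example (illustrative), with vector length 16 this yields a fixed type
/// such as vector<16xf32>, or a scalable type such as vector<[16]xf32> when
/// VLA vectorization is enabled.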
static VectorType vectorType(CodeGen &codegen, Type etp) {
  unsigned numScalableDims = codegen.options.enableVLAVectorization;
  return VectorType::get(codegen.curVecLength, etp, numScalableDims);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, OpBuilder &builder, Value iv,
                           Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, builder.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return builder.create<vector::BroadcastOp>(
          loc, mtp, constantI1(builder, loc, true));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {builder.getAffineSymbolExpr(0),
       builder.getAffineDimExpr(0) - builder.getAffineDimExpr(1)},
      builder.getContext());
  Value end =
      builder.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
  return builder.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, OpBuilder &builder, Value ptr,
                           ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = constantZero(builder, loc, vtp);
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = constantIndex(builder, loc, 0);
    return builder.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs, indexVec,
                                            codegen.curVecMask, pass);
  }
  return builder.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                              codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, OpBuilder &builder, Value rhs,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = constantIndex(builder, loc, 0);
    builder.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                      codegen.curVecMask, rhs);
    return;
  }
  builder.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                        rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
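/// For example (illustrative), a loop-invariant scalar %c used inside a
/// vector loop is broadcast into a vector such as vector<16xf32>, after which
/// loop-invariant code motion can move the broadcast out of the loop.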
647 static Value genVectorInvariantValue(CodeGen &codegen, OpBuilder &builder, 648 Value val) { 649 VectorType vtp = vectorType(codegen, val.getType()); 650 return builder.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 651 } 652 653 /// Generates an affine expression. 654 // 655 // TODO: generalize for sparse tensor subscripts 656 // 657 static Value genAffine(CodeGen &codegen, OpBuilder &builder, AffineExpr a, 658 Location loc) { 659 switch (a.getKind()) { 660 case AffineExprKind::DimId: { 661 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 662 return codegen.loops[idx]; // universal dense index 663 } 664 case AffineExprKind::Add: { 665 auto binOp = a.cast<AffineBinaryOpExpr>(); 666 return builder.create<arith::AddIOp>( 667 loc, genAffine(codegen, builder, binOp.getLHS(), loc), 668 genAffine(codegen, builder, binOp.getRHS(), loc)); 669 } 670 case AffineExprKind::Mul: { 671 auto binOp = a.cast<AffineBinaryOpExpr>(); 672 return builder.create<arith::MulIOp>( 673 loc, genAffine(codegen, builder, binOp.getLHS(), loc), 674 genAffine(codegen, builder, binOp.getRHS(), loc)); 675 } 676 case AffineExprKind::Constant: { 677 int64_t c = a.cast<AffineConstantExpr>().getValue(); 678 return constantIndex(builder, loc, c); 679 } 680 default: 681 llvm_unreachable("unexpected affine subscript"); 682 } 683 } 684 685 /// Generates index for load/store on sparse tensor. 686 static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) { 687 auto map = op.getTiedIndexingMap(t); 688 auto enc = getSparseTensorEncoding(t->get().getType()); 689 AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1)); 690 assert(a.getKind() == AffineExprKind::DimId); 691 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 692 return codegen.loops[idx]; 693 } 694 695 /// Generates subscript for load/store on a dense or sparse tensor. 696 static Value genSubscript(CodeGen &codegen, OpBuilder &builder, 697 linalg::GenericOp op, OpOperand *t, 698 SmallVector<Value, 4> &args) { 699 unsigned tensor = t->getOperandNumber(); 700 auto map = op.getTiedIndexingMap(t); 701 auto enc = getSparseTensorEncoding(t->get().getType()); 702 unsigned rank = map.getNumResults(); 703 if (enc) { 704 // Note that currently, all sparse subscripts are simple. 705 // TODO: accept affine too? 706 AffineExpr a = map.getResult(perm(enc, rank - 1)); 707 assert(a.getKind() == AffineExprKind::DimId); 708 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 709 assert(codegen.pidxs[tensor][idx] != nullptr); 710 args.push_back(codegen.pidxs[tensor][idx]); // position index 711 } else { 712 for (unsigned d = 0; d < rank; d++) { 713 AffineExpr a = map.getResult(perm(enc, d)); 714 args.push_back(genAffine(codegen, builder, a, op.getLoc())); 715 } 716 } 717 return codegen.buffers[tensor]; 718 } 719 720 /// Generates insertion code to implement dynamic tensor load. 721 static Value genInsertionLoad(CodeGen &codegen, OpBuilder &builder, 722 linalg::GenericOp op, OpOperand *t) { 723 Location loc = op.getLoc(); 724 // Direct lexicographic index order, tensor loads as zero. 725 if (!codegen.expValues) { 726 Type tp = getElementTypeOrSelf(t->get().getType()); 727 return constantZero(builder, loc, tp); 728 } 729 // Load from expanded access pattern. 730 Value index = genIndex(codegen, op, t); 731 return builder.create<memref::LoadOp>(loc, codegen.expValues, index); 732 } 733 734 /// Generates insertion code to implement dynamic tensor store. 
735 static void genInsertionStore(CodeGen &codegen, OpBuilder &builder, 736 linalg::GenericOp op, OpOperand *t, Value rhs) { 737 Location loc = op.getLoc(); 738 // Direct insertion in lexicographic index order. 739 if (!codegen.expValues) { 740 builder.create<memref::StoreOp>(loc, rhs, codegen.lexVal); 741 builder.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, codegen.lexVal); 742 return; 743 } 744 // Generates insertion code along expanded access pattern. 745 // if (!expFilled[i]) then 746 // expFilled[i] = true 747 // expAdded[inserts++] = i 748 // endif 749 // values[i] = rhs 750 Value index = genIndex(codegen, op, t); 751 Value fval = constantI1(builder, loc, false); 752 Value tval = constantI1(builder, loc, true); 753 // If statement. 754 Value filled = builder.create<memref::LoadOp>(loc, codegen.expFilled, index); 755 Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 756 filled, fval); 757 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond, 758 /*else=*/true); 759 // True branch. 760 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 761 builder.create<memref::StoreOp>(loc, tval, codegen.expFilled, index); 762 builder.create<memref::StoreOp>(loc, index, codegen.expAdded, 763 codegen.expCount); 764 Value one = constantIndex(builder, loc, 1); 765 Value add = builder.create<arith::AddIOp>(loc, codegen.expCount, one); 766 builder.create<scf::YieldOp>(loc, add); 767 // False branch. 768 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 769 builder.create<scf::YieldOp>(loc, codegen.expCount); 770 builder.setInsertionPointAfter(ifOp); 771 // Value assignment. 772 codegen.expCount = ifOp.getResult(0); 773 builder.create<memref::StoreOp>(loc, rhs, codegen.expValues, index); 774 } 775 776 /// Generates a load on a dense or sparse tensor. 777 static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder, 778 linalg::GenericOp op, unsigned exp) { 779 // Test if the load was hoisted to a higher loop nest. 780 Value val = merger.exp(exp).val; 781 if (val) { 782 if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>()) 783 return genVectorInvariantValue(codegen, builder, val); 784 return val; 785 } 786 // Load during insertion. 787 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 788 if (t == codegen.sparseOut) 789 return genInsertionLoad(codegen, builder, op, t); 790 // Actual load. 791 SmallVector<Value, 4> args; 792 Value ptr = genSubscript(codegen, builder, op, t, args); 793 if (codegen.curVecLength > 1) 794 return genVectorLoad(codegen, builder, ptr, args); 795 return builder.create<memref::LoadOp>(op.getLoc(), ptr, args); 796 } 797 798 /// Generates a store on a dense or sparse tensor. 799 static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder, 800 linalg::GenericOp op, unsigned exp, Value rhs) { 801 Location loc = op.getLoc(); 802 // Test if this is a scalarized reduction. 803 if (codegen.redVal) { 804 if (codegen.curVecLength > 1) 805 rhs = builder.create<arith::SelectOp>(loc, codegen.curVecMask, rhs, 806 codegen.redVal); 807 updateReduc(merger, codegen, rhs); 808 return; 809 } 810 // Store during insertion. 811 OpOperand *t = op.getOutputOperand(0); 812 if (t == codegen.sparseOut) { 813 if (!rhs) { 814 // Only unary and binary are allowed to return uninitialized rhs 815 // to indicate missing output. 
      assert(merger.exp(exp).kind == kUnary ||
             merger.exp(exp).kind == kBinary);
    } else {
      genInsertionStore(codegen, builder, op, t, rhs);
    }
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, builder, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, builder, rhs, ptr, args);
  else
    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, OpBuilder &builder, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, builder, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = builder.create<arith::ExtUIOp>(
            loc, vectorType(codegen, builder.getI32Type()), vload);
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = builder.create<arith::ExtUIOp>(
            loc, vectorType(codegen, builder.getI64Type()), vload);
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = builder.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
    load =
        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               OpBuilder &builder, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, builder, val);
  return val;
}

/// Generates an address computation "sz * p + i".
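/// For example (illustrative), for a dense inner dimension of size sz, the
/// element at outer position p and inner index i lives at the linearized
/// offset sz * p + i; when i is already a vector of indices, the scalar
/// product sz * p is broadcast before the addition.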
887 static Value genAddress(CodeGen &codegen, OpBuilder &builder, Location loc, 888 Value size, Value p, Value i) { 889 Value mul = builder.create<arith::MulIOp>(loc, size, p); 890 if (auto vtp = i.getType().dyn_cast<VectorType>()) { 891 Value inv = 892 builder.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul); 893 mul = genVectorInvariantValue(codegen, builder, inv); 894 } 895 return builder.create<arith::AddIOp>(loc, mul, i); 896 } 897 898 /// Generates an index value. 899 static Value genIndexValue(CodeGen &codegen, OpBuilder &builder, unsigned idx, 900 unsigned ldx) { 901 Value ival = codegen.loops[idx]; 902 Type itype = ival.getType(); 903 // During vectorization, we either encounter: 904 // (1) indices already in vector form, as in ... = ind[lo:hi], good to go, or 905 // (2) single index, as in ... = i, must convert to [i, i+1, ...] for inner i. 906 unsigned vl = codegen.curVecLength; 907 if (vl > 1 && !itype.isa<VectorType>()) { 908 Location loc = ival.getLoc(); 909 VectorType vtp = vectorType(codegen, itype); 910 ival = builder.create<vector::BroadcastOp>(loc, vtp, ival); 911 if (idx == ldx) { 912 Value incr; 913 if (vtp.isScalable()) { 914 Type stepvty = vectorType(codegen, builder.getI64Type()); 915 Value stepv = builder.create<LLVM::StepVectorOp>(loc, stepvty); 916 incr = builder.create<arith::IndexCastOp>(loc, vtp, stepv); 917 } else { 918 SmallVector<APInt, 4> integers; 919 for (unsigned i = 0; i < vl; i++) 920 integers.push_back(APInt(/*width=*/64, i)); 921 auto values = DenseElementsAttr::get(vtp, integers); 922 incr = builder.create<arith::ConstantOp>(loc, vtp, values); 923 } 924 ival = builder.create<arith::AddIOp>(loc, ival, incr); 925 } 926 } 927 return ival; 928 } 929 930 /// Semi-ring branches are simply inlined by the sparse compiler. Prior 931 /// analysis has verified that all computations are "local" to the inlined 932 /// branch or otherwise invariantly defined outside the loop nest, with the 933 /// exception of index computations, which need to be relinked to actual 934 /// inlined cloned code. 935 static Value relinkBranch(CodeGen &codegen, RewriterBase &rewriter, 936 Block *block, Value e, unsigned ldx) { 937 if (Operation *def = e.getDefiningOp()) { 938 if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) 939 return genIndexValue(codegen, rewriter, indexOp.dim(), ldx); 940 if (def->getBlock() == block) { 941 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) 942 def->setOperand( 943 i, relinkBranch(codegen, rewriter, block, def->getOperand(i), ldx)); 944 } 945 } 946 return e; 947 } 948 949 /// Recursively generates tensor expression. 
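/// For example (illustrative), an expression such as a(i) * b(i) + c(i) is
/// generated bottom-up: the loads of a, b, and c are emitted first, and the
/// arithmetic that combines them is then built from the merger's expression
/// tree.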
950 static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, 951 linalg::GenericOp op, unsigned exp, unsigned ldx) { 952 Location loc = op.getLoc(); 953 if (exp == -1u) 954 return Value(); 955 if (merger.exp(exp).kind == Kind::kTensor) 956 return genTensorLoad(merger, codegen, rewriter, op, exp); 957 if (merger.exp(exp).kind == Kind::kInvariant) 958 return genInvariantValue(merger, codegen, rewriter, exp); 959 if (merger.exp(exp).kind == Kind::kIndex) 960 return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx); 961 Value v0 = 962 genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx); 963 Value v1 = 964 genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx); 965 Value ee = merger.buildExp(rewriter, loc, exp, v0, v1); 966 if (ee && (merger.exp(exp).kind == Kind::kUnary || 967 merger.exp(exp).kind == Kind::kBinary || 968 merger.exp(exp).kind == Kind::kBinaryBranch)) 969 ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx); 970 return ee; 971 } 972 973 /// Determines if affine expression is invariant. 974 static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a, 975 unsigned ldx, bool &atLevel) { 976 switch (a.getKind()) { 977 case AffineExprKind::DimId: { 978 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 979 if (idx == ldx) 980 atLevel = true; 981 return codegen.loops[idx] != nullptr; // no longer in play? 982 } 983 case AffineExprKind::Add: 984 case AffineExprKind::Mul: { 985 auto binOp = a.cast<AffineBinaryOpExpr>(); 986 return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) && 987 isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel); 988 } 989 default: 990 return true; 991 } 992 } 993 994 /// Hoists loop invariant tensor loads for which indices have been exhausted. 995 static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder, 996 linalg::GenericOp op, unsigned exp, unsigned ldx, 997 bool atStart, Kind last = Kind::kTensor) { 998 if (exp == -1u) 999 return; 1000 if (merger.exp(exp).kind == Kind::kTensor) { 1001 // Inspect tensor indices. 1002 bool atLevel = ldx == -1u; 1003 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 1004 auto map = op.getTiedIndexingMap(t); 1005 auto enc = getSparseTensorEncoding(t->get().getType()); 1006 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1007 AffineExpr a = map.getResult(perm(enc, d)); 1008 if (!isInvariantAffine(codegen, a, ldx, atLevel)) 1009 return; // still in play 1010 } 1011 // All exhausted at this level (atLevel denotes exactly at this level). 1012 if (!atLevel) 1013 return; 1014 OpOperand *lhs = op.getOutputOperand(0); 1015 if (lhs == t) { 1016 // Start or end a scalarized reduction 1017 if (atStart) { 1018 Value load = genTensorLoad(merger, codegen, builder, op, exp); 1019 codegen.redKind = getReduction(last); 1020 codegen.redExp = exp; 1021 updateReduc(merger, codegen, load); 1022 } else { 1023 Value redVal = codegen.redVal; 1024 updateReduc(merger, codegen, Value()); 1025 codegen.redExp = -1u; 1026 codegen.redKind = kNoReduc; 1027 genTensorStore(merger, codegen, builder, op, exp, redVal); 1028 } 1029 } else { 1030 // Start or end loop invariant hoisting of a tensor load. 1031 merger.exp(exp).val = 1032 atStart ? genTensorLoad(merger, codegen, builder, op, exp) : Value(); 1033 } 1034 } else if (merger.exp(exp).kind != Kind::kInvariant && 1035 merger.exp(exp).kind != Kind::kIndex) { 1036 // Traverse into the binary operations. 
Note that we only hoist 1037 // tensor loads, since subsequent MLIR/LLVM passes know how to 1038 // deal with all other kinds of derived loop invariants. 1039 Kind last = merger.exp(exp).kind; 1040 unsigned e0 = merger.exp(exp).children.e0; 1041 unsigned e1 = merger.exp(exp).children.e1; 1042 genInvariants(merger, codegen, builder, op, e0, ldx, atStart, last); 1043 genInvariants(merger, codegen, builder, op, e1, ldx, atStart, last); 1044 } 1045 } 1046 1047 /// Generates an expanded access pattern in innermost dimension. 1048 static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1049 linalg::GenericOp op, unsigned at, bool atStart) { 1050 OpOperand *lhs = codegen.sparseOut; 1051 if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 || 1052 at != codegen.outerParNest) 1053 return; // not needed at this level 1054 // Generate start or end of an expanded access pattern. 1055 Value tensor = lhs->get(); 1056 Location loc = op.getLoc(); 1057 if (atStart) { 1058 auto dynShape = {ShapedType::kDynamicSize}; 1059 Type etp = tensor.getType().cast<ShapedType>().getElementType(); 1060 Type t1 = MemRefType::get(dynShape, etp); 1061 Type t2 = MemRefType::get(dynShape, builder.getI1Type()); 1062 Type t3 = MemRefType::get(dynShape, builder.getIndexType()); 1063 Type t4 = builder.getIndexType(); 1064 auto res = 1065 builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor); 1066 assert(res.getNumResults() == 4); 1067 assert(!codegen.expValues); 1068 codegen.expValues = res.getResult(0); 1069 codegen.expFilled = res.getResult(1); 1070 codegen.expAdded = res.getResult(2); 1071 codegen.expCount = res.getResult(3); 1072 } else { 1073 assert(codegen.expValues); 1074 builder.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues, 1075 codegen.expFilled, codegen.expAdded, 1076 codegen.expCount); 1077 codegen.expValues = codegen.expFilled = codegen.expAdded = 1078 codegen.expCount = Value(); 1079 } 1080 } 1081 1082 /// Generates initialization code for the subsequent loop sequence at 1083 /// current index level. Returns true if the loop sequence needs to 1084 /// maintain the universal index. 1085 static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1086 linalg::GenericOp op, std::vector<unsigned> &topSort, 1087 unsigned at, BitVector &inits) { 1088 bool needsUniv = false; 1089 Location loc = op.getLoc(); 1090 unsigned idx = topSort[at]; 1091 1092 // Initialize sparse positions. 1093 for (unsigned b = 0, be = inits.size(); b < be; b++) { 1094 if (inits[b]) { 1095 unsigned tensor = merger.tensor(b); 1096 assert(idx == merger.index(b)); 1097 if (merger.isDim(b, Dim::kSparse)) { 1098 // Initialize sparse index. 1099 unsigned pat = at; 1100 for (; pat != 0; pat--) { 1101 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1102 break; 1103 } 1104 Value ptr = codegen.pointers[tensor][idx]; 1105 Value one = constantIndex(builder, loc, 1); 1106 Value p0 = (pat == 0) ? constantIndex(builder, loc, 0) 1107 : codegen.pidxs[tensor][topSort[pat - 1]]; 1108 codegen.pidxs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p0); 1109 Value p1 = builder.create<arith::AddIOp>(loc, p0, one); 1110 codegen.highs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p1); 1111 } else { 1112 // Dense index still in play. 1113 needsUniv = true; 1114 } 1115 } 1116 } 1117 1118 // Initialize the universal dense index. 1119 codegen.loops[idx] = constantIndex(builder, loc, 0); 1120 return needsUniv; 1121 } 1122 1123 /// Returns vectorization strategy. 
Any implicit inner loop in the Linalg 1124 /// operation is a candidate. Whether it is actually converted to SIMD code 1125 /// depends on the requested strategy. 1126 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isReduction, 1127 bool isSparse) { 1128 // Reject vectorization of sparse output, unless innermost is reduction. 1129 if (codegen.sparseOut && !isReduction) 1130 return false; 1131 // Inspect strategy. 1132 switch (codegen.options.vectorizationStrategy) { 1133 case SparseVectorizationStrategy::kNone: 1134 return false; 1135 case SparseVectorizationStrategy::kDenseInnerLoop: 1136 return isInner && !isSparse; 1137 case SparseVectorizationStrategy::kAnyStorageInnerLoop: 1138 return isInner; 1139 } 1140 llvm_unreachable("unexpected vectorization strategy"); 1141 } 1142 1143 /// Returns parallelization strategy. Any implicit loop in the Linalg operation 1144 /// that is marked "parallel" is a candidate. Whether it is actually converted 1145 /// to a parallel operation depends on the requested strategy. 1146 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, 1147 bool isSparse, bool isVector) { 1148 // Reject parallelization of sparse output. 1149 if (codegen.sparseOut) 1150 return false; 1151 // Inspect strategy. 1152 switch (codegen.options.parallelizationStrategy) { 1153 case SparseParallelizationStrategy::kNone: 1154 return false; 1155 case SparseParallelizationStrategy::kDenseOuterLoop: 1156 return isOuter && !isSparse && !isReduction && !isVector; 1157 case SparseParallelizationStrategy::kAnyStorageOuterLoop: 1158 return isOuter && !isReduction && !isVector; 1159 case SparseParallelizationStrategy::kDenseAnyLoop: 1160 return !isSparse && !isReduction && !isVector; 1161 case SparseParallelizationStrategy::kAnyStorageAnyLoop: 1162 return !isReduction && !isVector; 1163 } 1164 llvm_unreachable("unexpected parallelization strategy"); 1165 } 1166 1167 /// Checks unit stride for dense tensors. The iteration graph may have ignored 1168 /// dense access patterns in order to avoid cycles (sparse access patterns are 1169 /// always placed innermost), but that means dense access has become strided. 1170 /// This prevents effective vectorization. 1171 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 1172 unsigned idx) { 1173 for (OpOperand *t : op.getInputAndOutputOperands()) { 1174 if (!getSparseTensorEncoding(t->get().getType())) { 1175 auto map = op.getTiedIndexingMap(t); 1176 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1177 AffineExpr a = map.getResult(d); 1178 // Report non-unit stride if innermost index appears at an outer 1179 // dimension (true non-unit stride) or if the innermost index appears 1180 // in a compound subscript in the innermost dimension. Even if the 1181 // latter is unit stride, it does not play well with scatter/gather. 1182 // TODO: accept unit stride affine innermost like a[i,j+k+1]? 1183 if (a.isFunctionOfDim(idx) && 1184 ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) 1185 return false; 1186 } 1187 } 1188 } 1189 return true; 1190 } 1191 1192 /// Generates a for-loop on a single index. 
1193 static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1194 linalg::GenericOp op, bool isOuter, bool isInner, 1195 unsigned idx, BitVector &indices) { 1196 unsigned fb = indices.find_first(); 1197 unsigned tensor = merger.tensor(fb); 1198 assert(idx == merger.index(fb)); 1199 auto iteratorTypes = op.iterator_types().getValue(); 1200 bool isReduction = isReductionIterator(iteratorTypes[idx]); 1201 bool isSparse = merger.isDim(fb, Dim::kSparse); 1202 bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) && 1203 denseUnitStrides(merger, op, idx); 1204 bool isParallel = 1205 isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 1206 1207 // Prepare vector length. 1208 if (isVector) 1209 codegen.curVecLength = codegen.options.vectorLength; 1210 1211 // Loop bounds and increment. 1212 Location loc = op.getLoc(); 1213 Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 1214 Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 1215 Value step = constantIndex(builder, loc, codegen.curVecLength); 1216 if (isVector && codegen.options.enableVLAVectorization) { 1217 Value vscale = builder.create<vector::VectorScaleOp>( 1218 loc, IndexType::get(builder.getContext())); 1219 step = builder.create<arith::MulIOp>(loc, vscale, step); 1220 } 1221 1222 // Emit a parallel loop. 1223 if (isParallel) { 1224 assert(!isVector); 1225 scf::ParallelOp parOp = builder.create<scf::ParallelOp>(loc, lo, hi, step); 1226 if (isSparse) 1227 codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 1228 else 1229 codegen.loops[idx] = parOp.getInductionVars()[0]; 1230 builder.setInsertionPointToStart(parOp.getBody()); 1231 return parOp; 1232 } 1233 1234 // Emit a sequential or vector loop. 1235 SmallVector<Value, 4> operands; 1236 if (codegen.redVal) { 1237 // In a vector loop, bring reduction into SIMD form, if not already. 1238 if (isVector && !codegen.redVal.getType().isa<VectorType>()) { 1239 VectorType vtp = vectorType(codegen, codegen.redVal.getType()); 1240 Value vred = genVectorReducInit(codegen, builder, loc, vtp); 1241 updateReduc(merger, codegen, vred); 1242 } 1243 operands.push_back(codegen.redVal); 1244 } 1245 if (codegen.expValues) 1246 operands.push_back(codegen.expCount); 1247 scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, operands); 1248 if (codegen.redVal) 1249 updateReduc(merger, codegen, forOp.getRegionIterArgs().front()); 1250 if (codegen.expValues) 1251 codegen.expCount = forOp.getRegionIterArgs().back(); 1252 // Assign induction variable to sparse or dense index. 1253 Value iv = forOp.getInductionVar(); 1254 if (isSparse) 1255 codegen.pidxs[tensor][idx] = iv; 1256 else 1257 codegen.loops[idx] = iv; 1258 builder.setInsertionPointToStart(forOp.getBody()); 1259 // Share vector iteration mask between all subsequent loads/stores. 1260 if (isVector) 1261 codegen.curVecMask = genVectorMask(codegen, builder, iv, lo, hi, step); 1262 return forOp; 1263 } 1264 1265 /// Emit a while-loop for co-iteration over multiple indices. 1266 static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder, 1267 linalg::GenericOp op, unsigned idx, bool needsUniv, 1268 BitVector &indices) { 1269 SmallVector<Type, 4> types; 1270 SmallVector<Value, 4> operands; 1271 // Construct the while-loop with a parameter for each index. 
  Type indexType = builder.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (codegen.redVal) {
    types.push_back(codegen.redVal.getType());
    operands.push_back(codegen.redVal);
  }
  if (codegen.expValues) {
    types.push_back(indexType);
    operands.push_back(codegen.expCount);
  }
  if (needsUniv) {
    types.push_back(indexType);
    operands.push_back(codegen.loops[idx]);
  }
  assert(types.size() == operands.size());
  Location loc = op.getLoc();
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);

  SmallVector<Location> locs(types.size(), loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  builder.setInsertionPointToStart(&whileOp.getBefore().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                op1, op2);
      cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (codegen.redVal)
    updateReduc(merger, codegen, after->getArgument(o++));
  if (codegen.expValues)
    codegen.expCount = after->getArgument(o++);
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
  builder.setInsertionPointToStart(&whileOp.getAfter().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                          linalg::GenericOp op, std::vector<unsigned> &topSort,
                          unsigned at, bool needsUniv, BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, builder, op, isOuter, isInner, idx, indices);
  }
  return genWhile(merger, codegen, builder, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                      linalg::GenericOp op, std::vector<unsigned> &topSort,
                      unsigned at, bool needsUniv, BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, builder, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp = builder.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, load, min);
          min = builder.create<arith::SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? constantIndex(builder, loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, builder, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }

  // Move the insertion indices in lexicographic index order. During access
  // pattern expansion, we can skip setting the innermost dimension.
  if (codegen.sparseOut && !codegen.expValues) {
    Value pos = constantIndex(builder, loc, at);
    builder.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
                                    pos);
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              OpBuilder &builder, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              BitVector &induction, scf::WhileOp whileOp) {
  Location loc = op.getLoc();
  // Finalize each else branch of all if-statements.
  if (codegen.redVal || codegen.expValues) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               builder.getInsertionBlock()->getParentOp())) {
      unsigned y = 0;
      SmallVector<Value, 4> yields;
      if (codegen.redVal) {
        yields.push_back(codegen.redVal);
        updateReduc(merger, codegen, ifOp.getResult(y++));
      }
      if (codegen.expValues) {
        yields.push_back(codegen.expCount);
        codegen.expCount = ifOp->getResult(y++);
      }
      assert(y == yields.size());
      builder.create<scf::YieldOp>(loc, yields);
      builder.setInsertionPointAfter(ifOp);
    }
  }
  builder.setInsertionPointToEnd(&whileOp.getAfter().front());
  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
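  // Conceptually, each sparse position is advanced only if its index
  // matched the current loop index, roughly (illustrative only):
  //
  //   %hit  = arith.cmpi eq, %idx_t, %i : index
  //   %next = arith.addi %pidx_t, %c1 : index
  //   %new  = arith.select %hit, %next, %pidx_t : index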
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = constantIndex(builder, loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                op1, op2);
      Value add = builder.create<arith::AddIOp>(loc, op3, one);
      operands.push_back(builder.create<arith::SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
    }
  }
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, whileOp->getResult(o++));
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = whileOp->getResult(o++);
  }
  if (needsUniv) {
    operands.push_back(
        builder.create<arith::AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = whileOp->getResult(o++);
  }
  assert(o == operands.size());
  builder.create<scf::YieldOp>(loc, operands);
  builder.setInsertionPointAfter(whileOp);
}

/// Generates the induction structure for a for-loop.
static void genForInduction(Merger &merger, CodeGen &codegen,
                            OpBuilder &builder, linalg::GenericOp op,
                            Operation *loop) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, loop->getResult(o++));
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = loop->getResult(o++);
  }
  assert(o == operands.size());
  if (o > 0)
    builder.create<scf::YieldOp>(loc, operands);
  builder.setInsertionPointAfter(loop);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                       linalg::GenericOp op, unsigned idx,
                       BitVector &conditions) {
  Location loc = op.getLoc();
  SmallVector<Type, 4> types;
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                               op1, op2);
      } else {
        clause = constantI1(builder, loc, true);
      }
      cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
    }
  }
  if (codegen.redVal)
    types.push_back(codegen.redVal.getType());
  if (codegen.expValues)
    types.push_back(builder.getIndexType());
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}

/// Generates the end of the true branch of an if-statement within a
/// while-loop.
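/// Together with genIf above, this yields code of the form (sketch only;
/// names are illustrative):
///
///   %r:N = scf.if %cond -> (...) {
///     <body emitted between genIf and endIf>
///     scf.yield %red, %cnt
///   } else {
///     <else-branch finalized later by genWhileInduction>
///   }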
static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                  linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
                  Value redInput, Value cntInput) {
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, redInput);
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = cntInput;
  }
  if (!operands.empty())
    builder.create<scf::YieldOp>(op.getLoc(), operands);
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at the given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                         linalg::GenericOp op, std::vector<unsigned> &topSort,
                         unsigned exp, unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  assert(!codegen.loops[idx]);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/true);
  // Emit access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, builder, op, at, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  bool needsUniv =
      genInit(merger, codegen, builder, op, topSort, at, merger.lat(l0).bits);
  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  if (needsUniv) {
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}

/// Starts a single loop in the current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            OpBuilder &builder, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, builder, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, builder, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in the current sequence. Returns the new value for
/// needsUniv.
static bool endLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    genWhileInduction(merger, codegen, builder, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp);
    return needsUniv;
  }
  // End a for-loop.
  genForInduction(merger, codegen, builder, op, loop);
  return false;
}

/// Ends a loop sequence at the given level.
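/// If a pending reduction is still in SIMD form at this point, it is
/// folded back into a scalar, conceptually something like (exact syntax
/// varies across MLIR versions):
///
///   %red = vector.reduction "add", %vred : vector<8xf64> into f64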
static void endLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                       linalg::GenericOp op, unsigned exp, unsigned at,
                       unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  codegen.loops[idx] = Value();
  // Bring a pending reduction back from SIMD form when the sequence ends.
  if (codegen.redVal)
    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
      updateReduc(merger, codegen,
                  genVectorReducEnd(codegen, builder, op.getLoc(), vtp));
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/false);
  // Finalize access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, builder, op, at, /*atStart=*/false);
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    unsigned ldx = topSort[at - 1];
    Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx);
    genTensorStore(merger, codegen, rewriter, op, exp, rhs);
    return;
  }

  // Construct iteration lattices for the current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if-statements for co-iteration.
    Value redInput = codegen.redVal;
    Value cntInput = codegen.expCount;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
}

/// Converts the result computed by the sparse kernel into the required form.
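/// For an annotated (sparse) output this amounts to a sparse_tensor.load
/// of the output operand; for a dense output, a bufferization.to_tensor
/// of the values buffer. Roughly (illustrative only):
///
///   %t = sparse_tensor.load %out : tensor<?x?xf64, #SparseMatrix>
///   %t = bufferization.to_tensor %buf : memref<?xf64>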
static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
                      linalg::GenericOp op) {
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format.
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
                                        codegen.sparseOut == lhs);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = codegen.buffers.back(); // value array
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
  }
}

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
    if (!optExp.hasValue())
      return failure();
    unsigned exp = optExp.getValue();

    // Rejects an inadmissible tensor expression.
    OpOperand *sparseOut = nullptr;
    unsigned outerParNest = 0;
    if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
                               outerParNest))
      return failure();

    // Recursively generates code.
    merger.setHasSparseOut(sparseOut != nullptr);
    CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
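  /// In particular, genFor above consults `options.vectorLength` and
  /// `options.enableVLAVectorization` to decide how loops are emitted.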
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
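
// Example usage (illustrative sketch; the pass boilerplate below is
// hypothetical and not part of this file):
//
//   void MyPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populateSparsificationPatterns(patterns, SparsificationOptions());
//     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//   }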