1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements converting sparse tensor types to actual sparse code. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 16 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/SCF/SCF.h" 21 #include "mlir/Dialect/SCF/Transforms.h" 22 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 23 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 24 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 25 #include "mlir/Dialect/StandardOps/IR/Ops.h" 26 #include "mlir/Dialect/Vector/VectorOps.h" 27 #include "mlir/IR/Matchers.h" 28 #include "mlir/IR/TensorEncoding.h" 29 #include "llvm/ADT/SmallBitVector.h" 30 31 using namespace mlir; 32 using namespace mlir::sparse_tensor; 33 34 //===----------------------------------------------------------------------===// 35 // Declarations of data structures. 36 //===----------------------------------------------------------------------===// 37 38 namespace { 39 40 // Iteration graph sorting. 41 enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 }; 42 43 // Reduction kinds. 44 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor }; 45 46 // Code generation. 47 struct CodeGen { 48 CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops, 49 OpOperand *op, unsigned nest) 50 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 51 pointers(numTensors, std::vector<Value>(numLoops)), 52 indices(numTensors, std::vector<Value>(numLoops)), 53 highs(numTensors, std::vector<Value>(numLoops)), 54 pidxs(numTensors, std::vector<Value>(numLoops)), 55 idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(), 56 redKind(kNoReduc), sparseOut(op), outerParNest(nest), lexIdx(), 57 expValues(), expFilled(), expAdded(), expCount(), curVecLength(1), 58 curVecMask() {} 59 /// Sparsification options. 60 SparsificationOptions options; 61 /// Universal dense indices and upper bounds (by index). The loops array 62 /// is updated with the value of the universal dense index in the current 63 /// loop. The sizes array is set once with the inferred dimension sizes. 64 std::vector<Value> loops; 65 std::vector<Value> sizes; 66 /// Buffers for storing dense and sparse numerical values (by tensor). 67 /// This array is set once during bufferization of all tensors. 68 std::vector<Value> buffers; 69 /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 70 /// This array is set once during bufferization of all sparse tensors. 71 std::vector<std::vector<Value>> pointers; 72 std::vector<std::vector<Value>> indices; 73 /// Sparse iteration information (by tensor and index). These arrays 74 /// are updated to remain current within the current loop. 
75 std::vector<std::vector<Value>> highs; 76 std::vector<std::vector<Value>> pidxs; 77 std::vector<std::vector<Value>> idxs; 78 /// Current reduction, updated during code generation. When indices of a 79 /// reduction are exhausted, all inner loops can use a scalarized reduction. 80 unsigned redExp; 81 Value redVal; 82 Reduction redKind; 83 // Sparse tensor as output. Implemented either through direct injective 84 // insertion in lexicographic index order (where indices are updated 85 // in the temporary array `lexIdx`) or through access pattern expansion 86 // in the innermost loop nest (`expValues` through `expCount`). 87 OpOperand *sparseOut; 88 unsigned outerParNest; 89 Value lexIdx; 90 Value expValues; 91 Value expFilled; 92 Value expAdded; 93 Value expCount; 94 // Current vector length and mask. 95 unsigned curVecLength; 96 Value curVecMask; 97 }; 98 99 } // namespace 100 101 //===----------------------------------------------------------------------===// 102 // Sparse compiler analysis methods. 103 //===----------------------------------------------------------------------===// 104 105 /// Helper method to apply dimension ordering permutation. 106 static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) { 107 if (enc) { 108 auto order = enc.getDimOrdering(); 109 if (order) { 110 assert(order.isPermutation()); 111 return order.getDimPosition(d); 112 } 113 } 114 return d; 115 } 116 117 /// Helper method to translate dim level type to internal representation. 118 static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) { 119 if (enc) { 120 SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 121 if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 122 return Dim::kSparse; 123 if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 124 return Dim::kSingle; 125 } 126 return Dim::kDense; 127 } 128 129 /// Helper method to inspect affine expressions. Rejects cases where the 130 /// same index is used more than once. Also rejects affine expressions 131 /// that are not a direct index for annotated tensors. 132 // TODO: accept more affine cases for sparse tensors 133 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim, 134 bool isDense) { 135 switch (a.getKind()) { 136 case AffineExprKind::DimId: { 137 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 138 if (!merger.isDim(tensor, idx, Dim::kUndef)) 139 return false; // used more than once 140 merger.setDim(tensor, idx, dim); 141 return true; 142 } 143 case AffineExprKind::Add: 144 case AffineExprKind::Mul: { 145 if (!isDense) 146 return false; 147 auto binOp = a.cast<AffineBinaryOpExpr>(); 148 return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) && 149 findAffine(merger, tensor, binOp.getRHS(), dim, isDense); 150 } 151 case AffineExprKind::Constant: 152 return isDense; 153 default: 154 return false; 155 } 156 } 157 158 /// Helper method to inspect sparse encodings in the tensor types. 159 /// Fills the per-dimension sparsity information for all tensors. 160 /// Returns true if the sparse annotations and affine subscript 161 /// expressions of all tensors are admissable. Returns false if 162 /// no annotations are found or inadmissable constructs occur. 
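/// As an illustrative (not exhaustive) example of the current restrictions:
/// a direct subscript such as A(i,j) is accepted for an annotated sparse A,
/// whereas a compound subscript such as A(i+j) is only accepted when A is a
/// non-annotated dense tensor.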
163 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 164 bool annotated = false; 165 for (OpOperand *t : op.getInputAndOutputOperands()) { 166 auto map = op.getTiedIndexingMap(t); 167 auto enc = getSparseTensorEncoding(t->get().getType()); 168 if (enc) 169 annotated = true; 170 assert(map.getNumResults() == op.getRank(t)); 171 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 172 unsigned tensor = t->getOperandNumber(); 173 AffineExpr a = map.getResult(perm(enc, d)); 174 if (!findAffine(merger, tensor, a, toDim(enc, d), !enc)) 175 return false; // inadmissable affine expression 176 } 177 } 178 return annotated; 179 } 180 181 /// A DFS helper to compute a topological sort. Note that recursion is 182 /// bounded by the number of implicit loops, which is always small. 183 /// Returns false when a cycle is detected. 184 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 185 std::vector<unsigned> &topSort, 186 std::vector<std::vector<bool>> &adjM) { 187 if (visit[i] != 0) 188 return visit[i] != 1; // 1 denotes cycle! 189 visit[i] = 1; 190 for (unsigned j = 0, e = visit.size(); j < e; j++) 191 if (adjM[i][j]) 192 if (!topSortDFS(j, visit, topSort, adjM)) 193 return false; 194 visit[i] = 2; 195 topSort.push_back(i); 196 return true; 197 } 198 199 /// Helper method to add all constraints from the indices in one affine 200 /// expression before all indices in the other affine expression. For 201 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3. 202 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM, 203 AffineExpr a, AffineExpr b, unsigned fidx) { 204 switch (a.getKind()) { 205 case AffineExprKind::DimId: { 206 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 207 if (b) 208 addAffineOrderings(adjM, b, AffineExpr(), idx); 209 else 210 adjM[fidx][idx] = true; 211 break; 212 } 213 case AffineExprKind::Add: 214 case AffineExprKind::Mul: { 215 auto binOp = a.cast<AffineBinaryOpExpr>(); 216 addAffineOrderings(adjM, binOp.getLHS(), b, fidx); 217 addAffineOrderings(adjM, binOp.getRHS(), b, fidx); 218 break; 219 } 220 default: 221 break; 222 } 223 } 224 225 /// Computes a topologically sorted iteration graph for the linalg operation. 226 /// Ensures all tensors are visited in natural index order. This is essential 227 /// for sparse storage formats since these only support access along fixed 228 /// dimensions. Even for dense storage formats, however, the natural index 229 /// order yields innermost unit-stride access with better spatial locality. 230 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, 231 std::vector<unsigned> &topSort, 232 unsigned mask) { 233 // Set up an n x n from/to adjacency matrix of the iteration graph 234 // for the implicit loop indices i_0 .. i_n-1. 235 unsigned n = op.getNumLoops(); 236 std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false)); 237 238 // Iterate over the indexing maps of every tensor in the tensor expression. 239 for (OpOperand *t : op.getInputAndOutputOperands()) { 240 auto map = op.getTiedIndexingMap(t); 241 auto enc = getSparseTensorEncoding(t->get().getType()); 242 assert(map.getNumDims() == n); 243 // Skip dense tensor constraints when not requested. 244 if (!(mask & SortMask::kIncludeDense) && !enc) 245 continue; 246 // Each tensor expression and optional dimension ordering (row-major 247 // by default) puts an ordering constraint on the loop indices. 
    // For example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if tensor has an in-place annotation.
static bool isInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(),
              linalg::comprehensive_bufferize::BufferizableOpInterface::
                  kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Returns true if tensor materializes uninitialized into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<InitOp>();
}

/// Returns true when the tensor expression is admissable for codegen.
/// Since all sparse input tensors are admissable, we just need to check
/// whether the out tensor in the tensor expression codegen is admissable.
/// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
/// nesting depth when a "truly dynamic" sparse tensor output occurs.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort, unsigned exp,
                                  OpOperand **sparseOut,
                                  unsigned &outerParNest) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissable since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissable since insertions cannot occur.
  bool allDense = true;
  auto iteratorTypes = op.iterator_types().getValue();
  unsigned numLoops = iteratorTypes.size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissable without special codegen, provided
  // the tensor's underlying sparse storage scheme can be modified in place.
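  // (Illustrative example: x(i) = x(i) * 2 on an in-place annotated x only
  //  overwrites values at already stored nonzero positions.)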
330 if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get())) 331 return true; 332 // Accept "truly dynamic" if the output tensor materializes uninitialized 333 // into the computation and insertions occur in lexicographic index order. 334 if (isMaterializing(lhs->get())) { 335 unsigned nest = 0; 336 for (unsigned i = 0; i < numLoops; i++) { 337 if (isReductionIterator(iteratorTypes[topSort[i]])) 338 break; // terminate at first reduction 339 nest++; 340 } 341 // Determine admissable dynamic insertion situations: 342 // (1) fully injective, since there are no reductions, 343 // (2) admissable 1-d expansion in innermost dimension. 344 if (nest >= op.getRank(lhs) - 1) { 345 *sparseOut = lhs; 346 outerParNest = nest; 347 return true; 348 } 349 } 350 return false; 351 } 352 353 //===----------------------------------------------------------------------===// 354 // Sparse compiler synthesis methods (reductions). 355 //===----------------------------------------------------------------------===// 356 357 /// Maps reduction kind to name encoding. 358 static StringRef getReductionName(Reduction kind) { 359 switch (kind) { 360 case kNoReduc: 361 break; 362 case kSum: 363 return "add"; 364 case kProduct: 365 return "mul"; 366 case kAnd: 367 return "and"; 368 case kOr: 369 return "or"; 370 case kXor: 371 return "xor"; 372 } 373 llvm_unreachable("unknown reduction kind"); 374 } 375 376 /// Maps operation to reduction. 377 static Reduction getReduction(Kind kind) { 378 switch (kind) { 379 case Kind::kAddF: 380 case Kind::kAddI: 381 case Kind::kSubF: 382 case Kind::kSubI: 383 return kSum; 384 case Kind::kMulF: 385 case Kind::kMulI: 386 return kProduct; 387 case Kind::kAndI: 388 return kAnd; 389 case Kind::kOrI: 390 return kOr; 391 case Kind::kXorI: 392 return kXor; 393 default: 394 llvm_unreachable("unexpected reduction operator"); 395 } 396 } 397 398 /// Generates an initial value for a vector reduction, following the scheme 399 /// given in Chapter 5 of "The Software Vectorization Handbook", where the 400 /// initial scalar value is correctly embedded in the vector reduction value, 401 /// and a straightforward horizontal reduction will complete the operation. 402 static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter, 403 Location loc, VectorType vtp) { 404 Value r = codegen.redVal; 405 switch (codegen.redKind) { 406 case kNoReduc: 407 break; 408 case kSum: 409 case kXor: { 410 // Initialize reduction vector to: | 0 | .. | 0 | r | 411 Attribute zero = rewriter.getZeroAttr(vtp); 412 Value vec = rewriter.create<arith::ConstantOp>(loc, vtp, zero); 413 return rewriter.create<vector::InsertElementOp>( 414 loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0)); 415 } 416 case kProduct: { 417 // Initialize reduction vector to: | 1 | .. | 1 | r | 418 Type etp = vtp.getElementType(); 419 Attribute one; 420 if (etp.isa<FloatType>()) 421 one = rewriter.getFloatAttr(etp, 1.0); 422 else 423 one = rewriter.getIntegerAttr(etp, 1); 424 Value vec = rewriter.create<arith::ConstantOp>( 425 loc, vtp, DenseElementsAttr::get(vtp, one)); 426 return rewriter.create<vector::InsertElementOp>( 427 loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0)); 428 } 429 case kAnd: 430 case kOr: 431 // Initialize reduction vector to: | r | .. | r | r | 432 return rewriter.create<vector::BroadcastOp>(loc, vtp, r); 433 } 434 llvm_unreachable("unknown reduction kind"); 435 } 436 437 /// Generates final value for a vector reduction. 
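/// For example (sketch only), a vectorized sum is finalized by horizontally
/// reducing the SIMD register with a vector.reduction "add" over the partial
/// sums accumulated in codegen.redVal.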
static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
                               Location loc, VectorType vtp) {
  StringRef name = getReductionName(codegen.redKind);
  StringAttr kind = rewriter.getStringAttr(name);
  return rewriter.create<vector::ReductionOp>(loc, vtp.getElementType(), kind,
                                              codegen.redVal, ValueRange{});
}

/// Updates scalarized reduction value.
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (isInPlace(tensor))
    return rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel. If the tensor
  // materializes into the computation, we need to preserve the zero
  // initialization assumption of all sparse output buffers.
  if (isMaterializing(tensor)) {
    Type tp = denseTp.getElementType();
    Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    rewriter.create<linalg::FillOp>(loc, zero, alloc);
    return alloc;
  }
  Value init = rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
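/// For a kernel such as x(i) = a(i) * b(i) with a compressed annotation on
/// operand a, this (roughly) materializes the pointers/indices/values arrays
/// of a through sparse_tensor.pointers/indices/values, and dense memrefs for
/// b and x through bufferization.to_memref (an illustrative sketch only).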
502 // TODO: replace this with a proliferated bufferization strategy 503 static void genBuffers(Merger &merger, CodeGen &codegen, 504 PatternRewriter &rewriter, linalg::GenericOp op) { 505 Location loc = op.getLoc(); 506 assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); 507 // For every tensor, find lower and upper bound on dimensions, set the 508 // same bounds on loop indices, and obtain dense or sparse buffer(s). 509 SmallVector<Value, 4> args; 510 for (OpOperand *t : op.getInputAndOutputOperands()) { 511 unsigned tensor = t->getOperandNumber(); 512 auto shape = op.getShape(t); 513 auto map = op.getTiedIndexingMap(t); 514 auto enc = getSparseTensorEncoding(t->get().getType()); 515 // Scan all dimensions of current tensor. 516 args.clear(); 517 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 518 AffineExpr a = map.getResult(perm(enc, d)); 519 if (a.getKind() != AffineExprKind::DimId) 520 continue; // compound 521 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 522 // Handle sparse storage schemes. 523 if (merger.isDim(tensor, idx, Dim::kSparse)) { 524 auto dynShape = {ShapedType::kDynamicSize}; 525 auto ptrTp = MemRefType::get( 526 dynShape, genIntType(rewriter, enc.getPointerBitWidth())); 527 auto indTp = MemRefType::get( 528 dynShape, genIntType(rewriter, enc.getIndexBitWidth())); 529 Value dim = rewriter.create<arith::ConstantIndexOp>(loc, d); 530 // Generate sparse primitives to obtains pointer and indices. 531 codegen.pointers[tensor][idx] = 532 rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim); 533 codegen.indices[tensor][idx] = 534 rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim); 535 } 536 // Find upper bound in current dimension. 537 unsigned p = perm(enc, d); 538 Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p); 539 if (shape[p] == MemRefType::kDynamicSize) 540 args.push_back(up); 541 assert(codegen.highs[tensor][idx] == nullptr); 542 codegen.sizes[idx] = codegen.highs[tensor][idx] = up; 543 } 544 // Perform the required bufferization. Dense inputs materialize 545 // from the input tensors. Dense outputs need special handling. 546 // Sparse inputs use sparse primitives to obtain the values. 547 // We also accept in-place all-dense annotated "sparse" outputs. 548 Type elementType = getElementTypeOrSelf(t->get().getType()); 549 if (!enc) { 550 // Non-annotated dense tensors. 551 auto denseTp = MemRefType::get(shape, elementType); 552 if (tensor < op.getNumInputs()) 553 codegen.buffers[tensor] = 554 rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, t->get()); 555 else 556 codegen.buffers[tensor] = 557 genOutputBuffer(codegen, rewriter, op, denseTp, args); 558 } else if (t == codegen.sparseOut) { 559 // True sparse output needs a lexIdx array. 560 Value rank = rewriter.create<arith::ConstantIndexOp>(loc, op.getRank(t)); 561 auto dynShape = {ShapedType::kDynamicSize}; 562 auto memTp = MemRefType::get(dynShape, rewriter.getIndexType()); 563 codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank); 564 } else { 565 // Annotated sparse tensors. 566 auto dynShape = {ShapedType::kDynamicSize}; 567 auto sparseTp = MemRefType::get(dynShape, elementType); 568 codegen.buffers[tensor] = 569 rewriter.create<ToValuesOp>(loc, sparseTp, t->get()); 570 } 571 } 572 } 573 574 /// Constructs vector type. 575 static VectorType vectorType(CodeGen &codegen, Type etp) { 576 return VectorType::get(codegen.curVecLength, etp); 577 } 578 579 /// Constructs vector type from pointer. 
580 static VectorType vectorType(CodeGen &codegen, Value ptr) { 581 return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType()); 582 } 583 584 /// Constructs vector iteration mask. 585 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter, 586 Value iv, Value lo, Value hi, Value step) { 587 Location loc = iv.getLoc(); 588 VectorType mtp = vectorType(codegen, genIntType(rewriter, 1)); 589 // Special case if the vector length evenly divides the trip count (for 590 // example, "for i = 0, 128, 16"). A constant all-true mask is generated 591 // so that all subsequent masked memory operations are immediately folded 592 // into unconditional memory operations. 593 IntegerAttr loInt, hiInt, stepInt; 594 if (matchPattern(lo, m_Constant(&loInt)) && 595 matchPattern(hi, m_Constant(&hiInt)) && 596 matchPattern(step, m_Constant(&stepInt))) { 597 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) 598 return rewriter.create<vector::BroadcastOp>( 599 loc, mtp, rewriter.create<arith::ConstantIntOp>(loc, 1, 1)); 600 } 601 // Otherwise, generate a vector mask that avoids overrunning the upperbound 602 // during vector execution. Here we rely on subsequent loop optimizations to 603 // avoid executing the mask in all iterations, for example, by splitting the 604 // loop into an unconditional vector loop and a scalar cleanup loop. 605 auto minMap = AffineMap::get( 606 /*dimCount=*/2, /*symbolCount=*/1, 607 {rewriter.getAffineSymbolExpr(0), 608 rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)}, 609 rewriter.getContext()); 610 Value end = 611 rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step}); 612 return rewriter.create<vector::CreateMaskOp>(loc, mtp, end); 613 } 614 615 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. 616 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter, 617 Value ptr, ArrayRef<Value> args) { 618 Location loc = ptr.getLoc(); 619 VectorType vtp = vectorType(codegen, ptr); 620 Value pass = 621 rewriter.create<arith::ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp)); 622 if (args.back().getType().isa<VectorType>()) { 623 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 624 Value indexVec = args.back(); 625 scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0); 626 return rewriter.create<vector::GatherOp>( 627 loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); 628 } 629 return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args, 630 codegen.curVecMask, pass); 631 } 632 633 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. 634 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter, 635 Value rhs, Value ptr, ArrayRef<Value> args) { 636 Location loc = ptr.getLoc(); 637 if (args.back().getType().isa<VectorType>()) { 638 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 639 Value indexVec = args.back(); 640 scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0); 641 rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, 642 codegen.curVecMask, rhs); 643 return; 644 } 645 rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask, 646 rhs); 647 } 648 649 /// Generates a vectorized invariant. Here we rely on subsequent loop 650 /// optimizations to hoist the invariant broadcast out of the vector loop. 
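/// For example (sketch), a scalar invariant %s inside a loop vectorized at
/// length 16 becomes %v = vector.broadcast %s : f64 to vector<16xf64>.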
651 static Value genVectorInvariantValue(CodeGen &codegen, 652 PatternRewriter &rewriter, Value val) { 653 VectorType vtp = vectorType(codegen, val.getType()); 654 return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 655 } 656 657 /// Generates an affine expression. 658 // 659 // TODO: generalize for sparse tensor subscripts 660 // 661 static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter, 662 AffineExpr a, Location loc) { 663 switch (a.getKind()) { 664 case AffineExprKind::DimId: { 665 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 666 return codegen.loops[idx]; // universal dense index 667 } 668 case AffineExprKind::Add: { 669 auto binOp = a.cast<AffineBinaryOpExpr>(); 670 return rewriter.create<arith::AddIOp>( 671 loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 672 genAffine(codegen, rewriter, binOp.getRHS(), loc)); 673 } 674 case AffineExprKind::Mul: { 675 auto binOp = a.cast<AffineBinaryOpExpr>(); 676 return rewriter.create<arith::MulIOp>( 677 loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 678 genAffine(codegen, rewriter, binOp.getRHS(), loc)); 679 } 680 case AffineExprKind::Constant: { 681 int64_t c = a.cast<AffineConstantExpr>().getValue(); 682 return rewriter.create<arith::ConstantIndexOp>(loc, c); 683 } 684 default: 685 llvm_unreachable("unexpected affine subscript"); 686 } 687 } 688 689 /// Generates index for load/store on sparse tensor. 690 static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) { 691 auto map = op.getTiedIndexingMap(t); 692 auto enc = getSparseTensorEncoding(t->get().getType()); 693 AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1)); 694 assert(a.getKind() == AffineExprKind::DimId); 695 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 696 return codegen.loops[idx]; 697 } 698 699 /// Generates subscript for load/store on a dense or sparse tensor. 700 static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter, 701 linalg::GenericOp op, OpOperand *t, 702 SmallVector<Value, 4> &args) { 703 unsigned tensor = t->getOperandNumber(); 704 auto map = op.getTiedIndexingMap(t); 705 auto enc = getSparseTensorEncoding(t->get().getType()); 706 unsigned rank = map.getNumResults(); 707 if (enc) { 708 // Note that currently, all sparse subscripts are simple. 709 // TODO: accept affine too? 710 AffineExpr a = map.getResult(perm(enc, rank - 1)); 711 assert(a.getKind() == AffineExprKind::DimId); 712 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 713 assert(codegen.pidxs[tensor][idx] != nullptr); 714 args.push_back(codegen.pidxs[tensor][idx]); // position index 715 } else { 716 for (unsigned d = 0; d < rank; d++) { 717 AffineExpr a = map.getResult(perm(enc, d)); 718 args.push_back(genAffine(codegen, rewriter, a, op.getLoc())); 719 } 720 } 721 return codegen.buffers[tensor]; 722 } 723 724 /// Generates insertion code to implement dynamic tensor load. 725 static Value genInsertionLoad(CodeGen &codegen, PatternRewriter &rewriter, 726 linalg::GenericOp op, OpOperand *t) { 727 Location loc = op.getLoc(); 728 // Direct lexicographic index order, tensor loads as zero. 729 if (!codegen.expValues) { 730 Type tp = getElementTypeOrSelf(t->get().getType()); 731 return rewriter.create<arith::ConstantOp>(loc, tp, 732 rewriter.getZeroAttr(tp)); 733 } 734 // Load from expanded access pattern. 735 Value index = genIndex(codegen, op, t); 736 return rewriter.create<memref::LoadOp>(loc, codegen.expValues, index); 737 } 738 739 /// Generates insertion code to implement dynamic tensor store. 
740 static void genInsertionStore(CodeGen &codegen, PatternRewriter &rewriter, 741 linalg::GenericOp op, OpOperand *t, Value rhs) { 742 Location loc = op.getLoc(); 743 // Direct insertion in lexicographic index order. 744 if (!codegen.expValues) { 745 rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs); 746 return; 747 } 748 // Generates insertion code along expanded access pattern. 749 // if (!expFilled[i]) then 750 // expFilled[i] = true 751 // expAdded[inserts++] = i 752 // endif 753 // values[i] = rhs 754 Value index = genIndex(codegen, op, t); 755 Value fval = rewriter.create<arith::ConstantIntOp>(loc, 0, 1); // false 756 Value tval = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true 757 // If statement. 758 Value filled = rewriter.create<memref::LoadOp>(loc, codegen.expFilled, index); 759 Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 760 filled, fval); 761 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, rewriter.getIndexType(), 762 cond, /*else=*/true); 763 // True branch. 764 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); 765 rewriter.create<memref::StoreOp>(loc, tval, codegen.expFilled, index); 766 rewriter.create<memref::StoreOp>(loc, index, codegen.expAdded, 767 codegen.expCount); 768 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 769 Value add = rewriter.create<arith::AddIOp>(loc, codegen.expCount, one); 770 rewriter.create<scf::YieldOp>(loc, add); 771 // False branch. 772 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); 773 rewriter.create<scf::YieldOp>(loc, codegen.expCount); 774 rewriter.setInsertionPointAfter(ifOp); 775 // Value assignment. 776 codegen.expCount = ifOp.getResult(0); 777 rewriter.create<memref::StoreOp>(loc, rhs, codegen.expValues, index); 778 } 779 780 /// Generates a load on a dense or sparse tensor. 781 static Value genTensorLoad(Merger &merger, CodeGen &codegen, 782 PatternRewriter &rewriter, linalg::GenericOp op, 783 unsigned exp) { 784 // Test if the load was hoisted to a higher loop nest. 785 Value val = merger.exp(exp).val; 786 if (val) { 787 if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>()) 788 return genVectorInvariantValue(codegen, rewriter, val); 789 return val; 790 } 791 // Load during insertion. 792 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 793 if (t == codegen.sparseOut) 794 return genInsertionLoad(codegen, rewriter, op, t); 795 // Actual load. 796 SmallVector<Value, 4> args; 797 Value ptr = genSubscript(codegen, rewriter, op, t, args); 798 if (codegen.curVecLength > 1) 799 return genVectorLoad(codegen, rewriter, ptr, args); 800 return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args); 801 } 802 803 /// Generates a store on a dense or sparse tensor. 804 static void genTensorStore(Merger &merger, CodeGen &codegen, 805 PatternRewriter &rewriter, linalg::GenericOp op, 806 Value rhs) { 807 Location loc = op.getLoc(); 808 // Test if this is a scalarized reduction. 809 if (codegen.redVal) { 810 if (codegen.curVecLength > 1) 811 rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs, 812 codegen.redVal); 813 updateReduc(merger, codegen, rhs); 814 return; 815 } 816 // Store during insertion. 817 OpOperand *t = op.getOutputOperand(0); 818 if (t == codegen.sparseOut) { 819 genInsertionStore(codegen, rewriter, op, t, rhs); 820 return; 821 } 822 // Actual store. 
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that are already 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load =
          rewriter.create<arith::ExtUIOp>(loc, load, genIntType(rewriter, 64));
    load =
        rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
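/// The computed value linearizes a dense dimension of size sz under parent
/// position p at loop index i; when i is a vector, the scalar part sz * p is
/// index-cast and broadcast into vector form first (see the code below).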
888 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter, 889 Location loc, Value size, Value p, Value i) { 890 Value mul = rewriter.create<arith::MulIOp>(loc, size, p); 891 if (auto vtp = i.getType().dyn_cast<VectorType>()) { 892 Value inv = 893 rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType()); 894 mul = genVectorInvariantValue(codegen, rewriter, inv); 895 } 896 return rewriter.create<arith::AddIOp>(loc, mul, i); 897 } 898 899 /// Recursively generates tensor expression. 900 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 901 linalg::GenericOp op, unsigned exp) { 902 Location loc = op.getLoc(); 903 if (exp == -1u) 904 return Value(); 905 if (merger.exp(exp).kind == Kind::kTensor) 906 return genTensorLoad(merger, codegen, rewriter, op, exp); 907 if (merger.exp(exp).kind == Kind::kInvariant) 908 return genInvariantValue(merger, codegen, rewriter, exp); 909 Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0); 910 Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1); 911 return merger.buildExp(rewriter, loc, exp, v0, v1); 912 } 913 914 /// Determines if affine expression is invariant. 915 static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a, 916 unsigned ldx, bool &atLevel) { 917 switch (a.getKind()) { 918 case AffineExprKind::DimId: { 919 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 920 if (idx == ldx) 921 atLevel = true; 922 return codegen.loops[idx] != nullptr; // no longer in play? 923 } 924 case AffineExprKind::Add: 925 case AffineExprKind::Mul: { 926 auto binOp = a.cast<AffineBinaryOpExpr>(); 927 return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) && 928 isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel); 929 } 930 default: 931 return true; 932 } 933 } 934 935 /// Hoists loop invariant tensor loads for which indices have been exhausted. 936 static void genInvariants(Merger &merger, CodeGen &codegen, 937 PatternRewriter &rewriter, linalg::GenericOp op, 938 unsigned exp, unsigned ldx, bool atStart, 939 Kind last = Kind::kTensor) { 940 if (exp == -1u) 941 return; 942 if (merger.exp(exp).kind == Kind::kTensor) { 943 // Inspect tensor indices. 944 bool atLevel = ldx == -1u; 945 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 946 auto map = op.getTiedIndexingMap(t); 947 auto enc = getSparseTensorEncoding(t->get().getType()); 948 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 949 AffineExpr a = map.getResult(perm(enc, d)); 950 if (!isInvariantAffine(codegen, a, ldx, atLevel)) 951 return; // still in play 952 } 953 // All exhausted at this level (atLevel denotes exactly at this level). 954 if (!atLevel) 955 return; 956 OpOperand *lhs = op.getOutputOperand(0); 957 if (lhs == t) { 958 // Start or end a scalarized reduction 959 if (atStart) { 960 Value load = genTensorLoad(merger, codegen, rewriter, op, exp); 961 codegen.redKind = getReduction(last); 962 codegen.redExp = exp; 963 updateReduc(merger, codegen, load); 964 } else { 965 Value redVal = codegen.redVal; 966 updateReduc(merger, codegen, Value()); 967 codegen.redExp = -1u; 968 codegen.redKind = kNoReduc; 969 genTensorStore(merger, codegen, rewriter, op, redVal); 970 } 971 } else { 972 // Start or end loop invariant hoisting of a tensor load. 973 merger.exp(exp).val = 974 atStart ? 
genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); 975 } 976 } else if (merger.exp(exp).kind != Kind::kInvariant) { 977 // Traverse into the binary operations. Note that we only hoist 978 // tensor loads, since subsequent MLIR/LLVM passes know how to 979 // deal with all other kinds of derived loop invariants. 980 Kind last = merger.exp(exp).kind; 981 unsigned e0 = merger.exp(exp).children.e0; 982 unsigned e1 = merger.exp(exp).children.e1; 983 genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last); 984 genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last); 985 } 986 } 987 988 /// Generates an expanded access pattern in innermost dimension. 989 static void genExpansion(Merger &merger, CodeGen &codegen, 990 PatternRewriter &rewriter, linalg::GenericOp op, 991 unsigned at, bool atStart) { 992 OpOperand *lhs = codegen.sparseOut; 993 if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 || 994 at != codegen.outerParNest) 995 return; // not needed at this level 996 // Generate start or end of an expanded access pattern. 997 Value tensor = lhs->get(); 998 Location loc = op.getLoc(); 999 if (atStart) { 1000 auto dynShape = {ShapedType::kDynamicSize}; 1001 Type etp = tensor.getType().cast<ShapedType>().getElementType(); 1002 Type t1 = MemRefType::get(dynShape, etp); 1003 Type t2 = MemRefType::get(dynShape, genIntType(rewriter, 1)); 1004 Type t3 = MemRefType::get(dynShape, genIntType(rewriter, 0)); 1005 Type t4 = rewriter.getIndexType(); 1006 auto res = 1007 rewriter.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor); 1008 assert(res.getNumResults() == 4); 1009 assert(!codegen.expValues); 1010 codegen.expValues = res.getResult(0); 1011 codegen.expFilled = res.getResult(1); 1012 codegen.expAdded = res.getResult(2); 1013 codegen.expCount = res.getResult(3); 1014 } else { 1015 assert(codegen.expValues); 1016 rewriter.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues, 1017 codegen.expFilled, codegen.expAdded, 1018 codegen.expCount); 1019 codegen.expValues = codegen.expFilled = codegen.expAdded = 1020 codegen.expCount = Value(); 1021 } 1022 } 1023 1024 /// Generates initialization code for the subsequent loop sequence at 1025 /// current index level. Returns true if the loop sequence needs to 1026 /// maintain the universal index. 1027 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 1028 linalg::GenericOp op, std::vector<unsigned> &topSort, 1029 unsigned at, llvm::BitVector &inits) { 1030 bool needsUniv = false; 1031 Location loc = op.getLoc(); 1032 unsigned idx = topSort[at]; 1033 1034 // Initialize sparse positions. 1035 for (unsigned b = 0, be = inits.size(); b < be; b++) { 1036 if (inits[b]) { 1037 unsigned tensor = merger.tensor(b); 1038 assert(idx == merger.index(b)); 1039 if (merger.isDim(b, Dim::kSparse)) { 1040 // Initialize sparse index. 1041 unsigned pat = at; 1042 for (; pat != 0; pat--) { 1043 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1044 break; 1045 } 1046 Value ptr = codegen.pointers[tensor][idx]; 1047 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 1048 Value p0 = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0) 1049 : codegen.pidxs[tensor][topSort[pat - 1]]; 1050 codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); 1051 Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one); 1052 codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1); 1053 } else { 1054 // Dense index still in play. 
1055 needsUniv = true; 1056 } 1057 } 1058 } 1059 1060 // Initialize the universal dense index. 1061 codegen.loops[idx] = rewriter.create<arith::ConstantIndexOp>(loc, 0); 1062 return needsUniv; 1063 } 1064 1065 /// Returns vectorization strategy. Any implicit inner loop in the Linalg 1066 /// operation is a candidate. Whether it is actually converted to SIMD code 1067 /// depends on the requested strategy. 1068 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) { 1069 switch (codegen.options.vectorizationStrategy) { 1070 case SparseVectorizationStrategy::kNone: 1071 return false; 1072 case SparseVectorizationStrategy::kDenseInnerLoop: 1073 return isInner && !isSparse; 1074 case SparseVectorizationStrategy::kAnyStorageInnerLoop: 1075 return isInner; 1076 } 1077 llvm_unreachable("unexpected vectorization strategy"); 1078 } 1079 1080 /// Returns parallelization strategy. Any implicit loop in the Linalg operation 1081 /// that is marked "parallel" is a candidate. Whether it is actually converted 1082 /// to a parallel operation depends on the requested strategy. 1083 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, 1084 bool isSparse, bool isVector) { 1085 switch (codegen.options.parallelizationStrategy) { 1086 case SparseParallelizationStrategy::kNone: 1087 return false; 1088 case SparseParallelizationStrategy::kDenseOuterLoop: 1089 return isOuter && !isSparse && !isReduction && !isVector; 1090 case SparseParallelizationStrategy::kAnyStorageOuterLoop: 1091 return isOuter && !isReduction && !isVector; 1092 case SparseParallelizationStrategy::kDenseAnyLoop: 1093 return !isSparse && !isReduction && !isVector; 1094 case SparseParallelizationStrategy::kAnyStorageAnyLoop: 1095 return !isReduction && !isVector; 1096 } 1097 llvm_unreachable("unexpected parallelization strategy"); 1098 } 1099 1100 /// Checks unit stride for dense tensors. The iteration graph may have ignored 1101 /// dense access patterns in order to avoid cycles (sparse access patterns are 1102 /// always placed innermost), but that means dense access has become strided. 1103 /// This prevents effective vectorization. 1104 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 1105 unsigned idx) { 1106 for (OpOperand *t : op.getInputAndOutputOperands()) { 1107 if (!getSparseTensorEncoding(t->get().getType())) { 1108 auto map = op.getTiedIndexingMap(t); 1109 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1110 AffineExpr a = map.getResult(d); 1111 // Report non-unit stride if innermost index appears at an outer 1112 // dimension (true non-unit stride) or if the innermost index appears 1113 // in a compound subscript in the innermost dimension. Even if the 1114 // latter is unit stride, it does not play well with scatter/gather. 1115 // TODO: accept unit stride affine innermost like a[i,j+k+1]? 1116 if (a.isFunctionOfDim(idx) && 1117 ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) 1118 return false; 1119 } 1120 } 1121 } 1122 return true; 1123 } 1124 1125 /// Generates a for-loop on a single index. 
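/// Depending on the chosen strategies this is emitted as scf.for, scf.parallel,
/// or a vector loop. As an illustrative sketch only, a sparse loop carrying a
/// scalarized reduction takes roughly the form:
///
///   %red = scf.for %p = %lo to %hi step %c1 iter_args(%r = %init) -> (f64) {
///     ...
///     scf.yield %newr : f64
///   }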
1126 static Operation *genFor(Merger &merger, CodeGen &codegen, 1127 PatternRewriter &rewriter, linalg::GenericOp op, 1128 bool isOuter, bool isInner, unsigned idx, 1129 llvm::BitVector &indices) { 1130 unsigned fb = indices.find_first(); 1131 unsigned tensor = merger.tensor(fb); 1132 assert(idx == merger.index(fb)); 1133 auto iteratorTypes = op.iterator_types().getValue(); 1134 bool isReduction = isReductionIterator(iteratorTypes[idx]); 1135 bool isSparse = merger.isDim(fb, Dim::kSparse); 1136 bool isVector = !codegen.sparseOut && 1137 isVectorFor(codegen, isInner, isSparse) && 1138 denseUnitStrides(merger, op, idx); 1139 bool isParallel = 1140 !codegen.sparseOut && 1141 isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 1142 1143 // Prepare vector length. 1144 if (isVector) 1145 codegen.curVecLength = codegen.options.vectorLength; 1146 1147 // Loop bounds and increment. 1148 Location loc = op.getLoc(); 1149 Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 1150 Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 1151 Value step = 1152 rewriter.create<arith::ConstantIndexOp>(loc, codegen.curVecLength); 1153 1154 // Emit a parallel loop. 1155 if (isParallel) { 1156 assert(!isVector); 1157 scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step); 1158 if (isSparse) 1159 codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 1160 else 1161 codegen.loops[idx] = parOp.getInductionVars()[0]; 1162 rewriter.setInsertionPointToStart(parOp.getBody()); 1163 return parOp; 1164 } 1165 1166 // Emit a sequential or vector loop. 1167 SmallVector<Value, 4> operands; 1168 if (codegen.redVal) { 1169 // In a vector loop, bring reduction into SIMD form, if not already. 1170 if (isVector && !codegen.redVal.getType().isa<VectorType>()) { 1171 VectorType vtp = vectorType(codegen, codegen.redVal.getType()); 1172 Value vred = genVectorReducInit(codegen, rewriter, loc, vtp); 1173 updateReduc(merger, codegen, vred); 1174 } 1175 operands.push_back(codegen.redVal); 1176 } 1177 if (codegen.expValues) 1178 operands.push_back(codegen.expCount); 1179 scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands); 1180 if (codegen.redVal) 1181 updateReduc(merger, codegen, forOp.getRegionIterArgs().front()); 1182 if (codegen.expValues) 1183 codegen.expCount = forOp.getRegionIterArgs().back(); 1184 // Assign induction variable to sparse or dense index. 1185 Value iv = forOp.getInductionVar(); 1186 if (isSparse) 1187 codegen.pidxs[tensor][idx] = iv; 1188 else 1189 codegen.loops[idx] = iv; 1190 rewriter.setInsertionPointToStart(forOp.getBody()); 1191 // Share vector iteration mask between all subsequent loads/stores. 1192 if (isVector) 1193 codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step); 1194 return forOp; 1195 } 1196 1197 /// Emit a while-loop for co-iteration over multiple indices. 1198 static Operation *genWhile(Merger &merger, CodeGen &codegen, 1199 PatternRewriter &rewriter, linalg::GenericOp op, 1200 unsigned idx, bool needsUniv, 1201 llvm::BitVector &indices) { 1202 SmallVector<Type, 4> types; 1203 SmallVector<Value, 4> operands; 1204 // Construct the while-loop with a parameter for each index. 
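  // The carried values are, in this order: one position index per sparse
  // tensor involved, the scalarized reduction (if any), the expanded access
  // pattern count (if any), and finally the universal index (if needed).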
1205 Type indexType = rewriter.getIndexType(); 1206 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1207 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1208 unsigned tensor = merger.tensor(b); 1209 assert(idx == merger.index(b)); 1210 types.push_back(indexType); 1211 operands.push_back(codegen.pidxs[tensor][idx]); 1212 } 1213 } 1214 if (codegen.redVal) { 1215 types.push_back(codegen.redVal.getType()); 1216 operands.push_back(codegen.redVal); 1217 } 1218 if (codegen.expValues) { 1219 types.push_back(indexType); 1220 operands.push_back(codegen.expCount); 1221 } 1222 if (needsUniv) { 1223 types.push_back(indexType); 1224 operands.push_back(codegen.loops[idx]); 1225 } 1226 assert(types.size() == operands.size()); 1227 Location loc = op.getLoc(); 1228 scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); 1229 Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types); 1230 Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, types); 1231 1232 // Build the "before" region, which effectively consists 1233 // of a conjunction of "i < upper" tests on all induction. 1234 rewriter.setInsertionPointToStart(&whileOp.getBefore().front()); 1235 Value cond; 1236 unsigned o = 0; 1237 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1238 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1239 unsigned tensor = merger.tensor(b); 1240 assert(idx == merger.index(b)); 1241 Value op1 = before->getArgument(o); 1242 Value op2 = codegen.highs[tensor][idx]; 1243 Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 1244 op1, op2); 1245 cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc; 1246 codegen.pidxs[tensor][idx] = after->getArgument(o++); 1247 } 1248 } 1249 if (codegen.redVal) 1250 updateReduc(merger, codegen, after->getArgument(o++)); 1251 if (codegen.expValues) 1252 codegen.expCount = after->getArgument(o++); 1253 if (needsUniv) 1254 codegen.loops[idx] = after->getArgument(o++); 1255 assert(o == operands.size()); 1256 rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); 1257 rewriter.setInsertionPointToStart(&whileOp.getAfter().front()); 1258 return whileOp; 1259 } 1260 1261 /// Generates a for-loop or a while-loop, depending on whether it implements 1262 /// singleton iteration or co-iteration over the given conjunction. 1263 static Operation *genLoop(Merger &merger, CodeGen &codegen, 1264 PatternRewriter &rewriter, linalg::GenericOp op, 1265 std::vector<unsigned> &topSort, unsigned at, 1266 bool needsUniv, llvm::BitVector &indices) { 1267 unsigned idx = topSort[at]; 1268 if (indices.count() == 1) { 1269 bool isOuter = at == 0; 1270 bool isInner = at == topSort.size() - 1; 1271 return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, 1272 indices); 1273 } 1274 return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); 1275 } 1276 1277 /// Generates the local variables for this loop, consisting of the sparse 1278 /// indices, restored universal dense index, and dense positions. 1279 static void genLocals(Merger &merger, CodeGen &codegen, 1280 PatternRewriter &rewriter, linalg::GenericOp op, 1281 std::vector<unsigned> &topSort, unsigned at, 1282 bool needsUniv, llvm::BitVector &locals) { 1283 Location loc = op.getLoc(); 1284 unsigned idx = topSort[at]; 1285 1286 // Initialize sparse indices. 
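  // Each sparse tensor loads idx = indices[pidx] at its current position and,
  // unless the universal index is still in use, the minimum over all loaded
  // indices becomes the loop index for this level.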
1287 Value min; 1288 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1289 if (locals[b] && merger.isDim(b, Dim::kSparse)) { 1290 unsigned tensor = merger.tensor(b); 1291 assert(idx == merger.index(b)); 1292 Value ptr = codegen.indices[tensor][idx]; 1293 Value s = codegen.pidxs[tensor][idx]; 1294 Value load = genLoad(codegen, rewriter, loc, ptr, s); 1295 codegen.idxs[tensor][idx] = load; 1296 if (!needsUniv) { 1297 if (min) { 1298 Value cmp = rewriter.create<arith::CmpIOp>( 1299 loc, arith::CmpIPredicate::ult, load, min); 1300 min = rewriter.create<SelectOp>(loc, cmp, load, min); 1301 } else { 1302 min = load; 1303 } 1304 } 1305 } 1306 } 1307 1308 // Merge dense universal index over minimum. 1309 if (min) { 1310 assert(!needsUniv); 1311 codegen.loops[idx] = min; 1312 } 1313 1314 // Initialize dense positions. Note that we generate dense indices of the 1315 // output tensor unconditionally, since they may not appear in the lattice, 1316 // but may be needed for linearized codegen. 1317 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1318 if ((locals[b] || merger.isOutTensor(b, idx)) && 1319 merger.isDim(b, Dim::kDense)) { 1320 unsigned tensor = merger.tensor(b); 1321 assert(idx == merger.index(b)); 1322 unsigned pat = at; 1323 for (; pat != 0; pat--) 1324 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1325 break; 1326 Value p = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0) 1327 : codegen.pidxs[tensor][topSort[pat - 1]]; 1328 codegen.pidxs[tensor][idx] = genAddress( 1329 codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); 1330 } 1331 } 1332 1333 // Move the insertion indices in lexicographic index order. During access 1334 // pattern expansion, we can skip setting the innermost dimension. 1335 if (codegen.sparseOut && !codegen.expValues) { 1336 Value pos = rewriter.create<arith::ConstantIndexOp>(loc, at); 1337 rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx, 1338 pos); 1339 } 1340 } 1341 1342 /// Generates the induction structure for a while-loop. 1343 static void genWhileInduction(Merger &merger, CodeGen &codegen, 1344 PatternRewriter &rewriter, linalg::GenericOp op, 1345 unsigned idx, bool needsUniv, 1346 llvm::BitVector &induction, 1347 scf::WhileOp whileOp) { 1348 Location loc = op.getLoc(); 1349 // Finalize each else branch of all if statements. 1350 if (codegen.redVal || codegen.expValues) { 1351 while (auto ifOp = dyn_cast_or_null<scf::IfOp>( 1352 rewriter.getInsertionBlock()->getParentOp())) { 1353 unsigned y = 0; 1354 SmallVector<Value, 4> yields; 1355 if (codegen.redVal) { 1356 yields.push_back(codegen.redVal); 1357 updateReduc(merger, codegen, ifOp.getResult(y++)); 1358 } 1359 if (codegen.expValues) { 1360 yields.push_back(codegen.expCount); 1361 codegen.expCount = ifOp->getResult(y++); 1362 } 1363 assert(y == yields.size()); 1364 rewriter.create<scf::YieldOp>(loc, yields); 1365 rewriter.setInsertionPointAfter(ifOp); 1366 } 1367 } 1368 rewriter.setInsertionPointToEnd(&whileOp.getAfter().front()); 1369 // Finalize the induction. Note that the induction could be performed 1370 // in the individual if-branches to avoid re-evaluating the conditions. 1371 // However, that would result in a rather elaborate forest of yield 1372 // instructions during code generation. Moreover, performing the induction 1373 // after the if-statements more closely resembles code generated by TACO. 
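  // In pseudo-code, each sparse position index advances as
  //   pidx = (idx == loopIdx) ? pidx + 1 : pidx
  // which the compare/select below implements.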
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 op1, op2);
      Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
    }
  }
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, whileOp->getResult(o++));
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = whileOp->getResult(o++);
  }
  if (needsUniv) {
    operands.push_back(
        rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = whileOp->getResult(o++);
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(whileOp);
}

/// Generates the induction structure for a for-loop.
static void genForInduction(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            Operation *loop) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, loop->getResult(o++));
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = loop->getResult(o++);
  }
  assert(o == operands.size());
  if (o > 0)
    rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(loop);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  SmallVector<Type, 4> types;
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                op1, op2);
      } else {
        clause = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
    }
  }
  if (codegen.redVal)
    types.push_back(codegen.redVal.getType());
  if (codegen.expValues)
    types.push_back(rewriter.getIndexType());
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}

/// Generates end of true branch of if-statement within a while-loop.
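/// As a minimal sketch (assuming only a scalarized reduction %red is live,
/// names illustrative), the true branch ends in
///
///   scf.yield %red : f64
///
/// after which code generation resumes in the else branch, with the pending
/// reduction and expansion count restored to their values from before the
/// branch (redInput and cntInput, respectively).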
static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                  linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
                  Value redInput, Value cntInput) {
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, redInput);
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = cntInput;
  }
  if (!operands.empty())
    rewriter.create<scf::YieldOp>(op.getLoc(), operands);
  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at the given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         std::vector<unsigned> &topSort, unsigned exp,
                         unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  assert(!codegen.loops[idx]);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
  // Emit access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  bool needsUniv =
      genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  if (needsUniv) {
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}

/// Starts a single loop in the current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in the current sequence. Returns the new value for
/// needsUniv.
static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp);
    return needsUniv;
  }
  // End a for-loop.
  genForInduction(merger, codegen, rewriter, op, loop);
  return false;
}

/// Ends a loop sequence at the given level.
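/// If the reduction was vectorized within this sequence, ending the sequence
/// also folds the SIMD value back into its scalar form. Roughly (assuming a
/// summation over f64 lanes; names and types illustrative):
///
///   %scalar = vector.reduction "add", %simd : vector<8xf64> into f64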
static void endLoopSeq(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned exp, unsigned at, unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  codegen.loops[idx] = Value();
  // Bring a pending reduction back from SIMD form when sequence ends.
  if (codegen.redVal)
    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
      updateReduc(merger, codegen,
                  genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
  // Finalize access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/false);
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    Value redInput = codegen.redVal;
    Value cntInput = codegen.expCount;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
}

/// Converts the result computed by the sparse kernel into the required form.
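/// A rough sketch of the two rematerialization paths (types and the
/// #SparseMatrix alias are illustrative only):
///
///   %1 = sparse_tensor.load %0 : tensor<?x?xf64, #SparseMatrix>
///
/// for an annotated output, versus
///
///   %1 = bufferization.to_tensor %values : memref<?xf64>
///
/// for a non-annotated output computed in a bufferized value array.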
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format.
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
                                        codegen.sparseOut == lhs);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = codegen.buffers.back(); // value array
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
  }
}

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
    if (!optExp.hasValue())
      return failure();
    unsigned exp = optExp.getValue();

    // Rejects an inadmissable tensor expression.
    OpOperand *sparseOut = nullptr;
    unsigned outerParNest = 0;
    if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
                               outerParNest))
      return failure();

    // Recursively generates code.
    merger.setHasSparseOut(sparseOut != nullptr);
    CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
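// A minimal usage sketch (hypothetical client code, not part of this file):
// a pass would typically populate these patterns and hand them off to the
// greedy pattern rewrite driver, e.g.
//
//   #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//
// where `ctx` is the MLIRContext and `options` a SparsificationOptions
// instance; the in-tree pass that actually drives this rewriting lives
// elsewhere in this dialect's Transforms library.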