1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements converting sparse tensor types to actual sparse code. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Utils/Utils.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SCF/SCF.h" 20 #include "mlir/Dialect/SCF/Transforms.h" 21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 22 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 23 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 24 #include "mlir/Dialect/StandardOps/IR/Ops.h" 25 #include "mlir/Dialect/Vector/VectorOps.h" 26 #include "mlir/IR/Matchers.h" 27 #include "mlir/IR/TensorEncoding.h" 28 #include "llvm/ADT/SmallBitVector.h" 29 30 using namespace mlir; 31 using namespace mlir::sparse_tensor; 32 33 //===----------------------------------------------------------------------===// 34 // Declarations of data structures. 35 //===----------------------------------------------------------------------===// 36 37 namespace { 38 39 // Iteration graph sorting. 40 enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 }; 41 42 // Reduction kinds. 43 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor }; 44 45 // Code generation. 46 struct CodeGen { 47 CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops, 48 OpOperand *op) 49 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 50 pointers(numTensors, std::vector<Value>(numLoops)), 51 indices(numTensors, std::vector<Value>(numLoops)), 52 highs(numTensors, std::vector<Value>(numLoops)), 53 pidxs(numTensors, std::vector<Value>(numLoops)), 54 idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(), 55 redKind(kNoReduc), sparseOut(op), lexIdx(), curVecLength(1), 56 curVecMask() {} 57 /// Sparsification options. 58 SparsificationOptions options; 59 /// Universal dense indices and upper bounds (by index). The loops array 60 /// is updated with the value of the universal dense index in the current 61 /// loop. The sizes array is set once with the inferred dimension sizes. 62 std::vector<Value> loops; 63 std::vector<Value> sizes; 64 /// Buffers for storing dense and sparse numerical values (by tensor). 65 /// This array is set once during bufferization of all tensors. 66 std::vector<Value> buffers; 67 /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 68 /// This array is set once during bufferization of all sparse tensors. 69 std::vector<std::vector<Value>> pointers; 70 std::vector<std::vector<Value>> indices; 71 /// Sparse iteration information (by tensor and index). These arrays 72 /// are updated to remain current within the current loop. 73 std::vector<std::vector<Value>> highs; 74 std::vector<std::vector<Value>> pidxs; 75 std::vector<std::vector<Value>> idxs; 76 /// Current reduction, updated during code generation. 
When indices of a 77 /// reduction are exhausted, all inner loops can use a scalarized reduction. 78 unsigned redExp; 79 Value redVal; 80 Reduction redKind; 81 // Sparse tensor as output. 82 OpOperand *sparseOut; 83 Value lexIdx; 84 // Current vector length and mask. 85 unsigned curVecLength; 86 Value curVecMask; 87 }; 88 89 } // namespace 90 91 //===----------------------------------------------------------------------===// 92 // Sparse compiler analysis methods. 93 //===----------------------------------------------------------------------===// 94 95 /// Helper method to apply dimension ordering permutation. 96 static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) { 97 if (enc) { 98 auto order = enc.getDimOrdering(); 99 if (order) { 100 assert(order.isPermutation()); 101 return order.getDimPosition(d); 102 } 103 } 104 return d; 105 } 106 107 /// Helper method to translate dim level type to internal representation. 108 static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) { 109 if (enc) { 110 SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 111 if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 112 return Dim::kSparse; 113 if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 114 return Dim::kSingle; 115 } 116 return Dim::kDense; 117 } 118 119 /// Helper method to inspect affine expressions. Rejects cases where the 120 /// same index is used more than once. Also rejects affine expressions 121 /// that are not a direct index for annotated tensors. 122 // TODO: accept more affine cases for sparse tensors 123 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim, 124 bool isDense) { 125 switch (a.getKind()) { 126 case AffineExprKind::DimId: { 127 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 128 if (!merger.isDim(tensor, idx, Dim::kUndef)) 129 return false; // used more than once 130 merger.setDim(tensor, idx, dim); 131 return true; 132 } 133 case AffineExprKind::Add: 134 case AffineExprKind::Mul: { 135 if (!isDense) 136 return false; 137 auto binOp = a.cast<AffineBinaryOpExpr>(); 138 return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) && 139 findAffine(merger, tensor, binOp.getRHS(), dim, isDense); 140 } 141 case AffineExprKind::Constant: 142 return isDense; 143 default: 144 return false; 145 } 146 } 147 148 /// Helper method to inspect sparse encodings in the tensor types. 149 /// Fills the per-dimension sparsity information for all tensors. 150 /// Returns true if the sparse annotations and affine subscript 151 /// expressions of all tensors are admissable. Returns false if 152 /// no annotations are found or inadmissable constructs occur. 153 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 154 bool annotated = false; 155 for (OpOperand *t : op.getInputAndOutputOperands()) { 156 auto map = op.getTiedIndexingMap(t); 157 auto enc = getSparseTensorEncoding(t->get().getType()); 158 if (enc) 159 annotated = true; 160 assert(map.getNumResults() == op.getRank(t)); 161 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 162 unsigned tensor = t->getOperandNumber(); 163 AffineExpr a = map.getResult(perm(enc, d)); 164 if (!findAffine(merger, tensor, a, toDim(enc, d), !enc)) 165 return false; // inadmissable affine expression 166 } 167 } 168 return annotated; 169 } 170 171 /// A DFS helper to compute a topological sort. Note that recursion is 172 /// bounded by the number of implicit loops, which is always small. 
173 /// Returns false when a cycle is detected. 174 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 175 std::vector<unsigned> &topSort, 176 std::vector<std::vector<bool>> &adjM) { 177 if (visit[i] != 0) 178 return visit[i] != 1; // 1 denotes cycle! 179 visit[i] = 1; 180 for (unsigned j = 0, e = visit.size(); j < e; j++) 181 if (adjM[i][j]) 182 if (!topSortDFS(j, visit, topSort, adjM)) 183 return false; 184 visit[i] = 2; 185 topSort.push_back(i); 186 return true; 187 } 188 189 /// Helper method to add all constraints from the indices in one affine 190 /// expression before all indices in the other affine expression. For 191 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3. 192 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM, 193 AffineExpr a, AffineExpr b, unsigned fidx) { 194 switch (a.getKind()) { 195 case AffineExprKind::DimId: { 196 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 197 if (b) 198 addAffineOrderings(adjM, b, AffineExpr(), idx); 199 else 200 adjM[fidx][idx] = true; 201 break; 202 } 203 case AffineExprKind::Add: 204 case AffineExprKind::Mul: { 205 auto binOp = a.cast<AffineBinaryOpExpr>(); 206 addAffineOrderings(adjM, binOp.getLHS(), b, fidx); 207 addAffineOrderings(adjM, binOp.getRHS(), b, fidx); 208 break; 209 } 210 default: 211 break; 212 } 213 } 214 215 /// Computes a topologically sorted iteration graph for the linalg operation. 216 /// Ensures all tensors are visited in natural index order. This is essential 217 /// for sparse storage formats since these only support access along fixed 218 /// dimensions. Even for dense storage formats, however, the natural index 219 /// order yields innermost unit-stride access with better spatial locality. 220 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, 221 std::vector<unsigned> &topSort, 222 unsigned mask) { 223 // Set up an n x n from/to adjacency matrix of the iteration graph 224 // for the implicit loop indices i_0 .. i_n-1. 225 unsigned n = op.getNumLoops(); 226 std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false)); 227 228 // Iterate over the indexing maps of every tensor in the tensor expression. 229 for (OpOperand *t : op.getInputAndOutputOperands()) { 230 auto map = op.getTiedIndexingMap(t); 231 auto enc = getSparseTensorEncoding(t->get().getType()); 232 assert(map.getNumDims() == n); 233 // Skip dense tensor constraints when not requested. 234 if (!(mask & SortMask::kIncludeDense) && !enc) 235 continue; 236 // Each tensor expression and optional dimension ordering (row-major 237 // by default) puts an ordering constraint on the loop indices. For 238 // example, the tensor expresion A_ijk forces the ordering i < j < k 239 // on the loop indices if no explicit dimension ordering is given. 240 for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) { 241 AffineExpr f = map.getResult(perm(enc, d - 1)); 242 AffineExpr t = map.getResult(perm(enc, d)); 243 addAffineOrderings(adjM, f, t, 0); 244 } 245 // Push unrelated loops into sparse iteration space, so these 246 // will be skipped more often. 247 if (mask & SortMask::kIncludeUndef) { 248 unsigned tensor = t->getOperandNumber(); 249 for (unsigned i = 0; i < n; i++) 250 if (merger.isDim(tensor, i, Dim::kSparse)) 251 for (unsigned j = 0; j < n; j++) 252 if (merger.isDim(tensor, j, Dim::kUndef)) 253 adjM[i][j] = true; 254 } 255 } 256 257 // Topologically sort the iteration graph to determine loop order. 258 // Report failure for a cyclic iteration graph. 
259 topSort.clear(); 260 topSort.reserve(n); 261 std::vector<unsigned> visit(n, 0); 262 for (unsigned i = 0; i < n; i++) 263 if (visit[i] == 0) 264 if (!topSortDFS(i, visit, topSort, adjM)) 265 return false; // cycle! 266 std::reverse(std::begin(topSort), std::end(topSort)); 267 return true; 268 } 269 270 /// Returns true if tensor has an in-place annotation. 271 static bool isInPlace(Value val) { 272 if (auto arg = val.dyn_cast<BlockArgument>()) 273 if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp())) 274 if (auto attr = funcOp.getArgAttrOfType<BoolAttr>( 275 arg.getArgNumber(), 276 linalg::comprehensive_bufferize::BufferizableOpInterface:: 277 kInplaceableAttrName)) 278 return attr.getValue(); 279 return false; 280 } 281 282 /// Returns true if tensor materializes uninitialized into the computation. 283 static bool isMaterializing(Value val) { 284 return val.getDefiningOp<linalg::InitTensorOp>() || 285 val.getDefiningOp<InitOp>(); 286 } 287 288 /// Returns true when the tensor expression is admissable for codegen. 289 /// Since all sparse input tensors are admissable, we just need to check 290 /// whether the output tensor in the tensor expression codegen is admissable. 291 /// Sets `sparseOut` when a "truly dynamic" sparse tensor output occurs. 292 static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op, 293 unsigned exp, OpOperand **sparseOut) { 294 OpOperand *lhs = op.getOutputOperand(0); 295 unsigned tensor = lhs->getOperandNumber(); 296 auto enc = getSparseTensorEncoding(lhs->get().getType()); 297 // An non-annotated output tensor is assumed dense, and becomes a random 298 // access n-dim memref. Admissable since insertions cannot occur. 299 if (!enc) 300 return true; 301 // An all-dense annotated "sparse" output tensor becomes a linearized random 302 // access 1-dim memref. Also admissable since insertions cannot occur. 303 bool allDense = true; 304 unsigned numLoops = op.iterator_types().getValue().size(); 305 for (unsigned i = 0; i < numLoops; i++) 306 if (merger.isDim(tensor, i, Dim::kSparse)) { 307 allDense = false; 308 break; 309 } 310 if (allDense) 311 return true; 312 // A tensor expression with a sparse output tensor that changes its values 313 // but not its nonzero structure, an operation called "simply dynamic" in 314 // [Bik96,Ch9], is also admissable without special codegen, provided 315 // the tensor's underlying sparse storage scheme can be modified in place. 316 if (merger.isConjunction(tensor, exp) && isInPlace(lhs->get())) 317 return true; 318 // Accept "truly dynamic" if the output tensor materializes uninitialized 319 // into the computation and insertions occur in lexicographic index order. 320 if (isMaterializing(lhs->get())) { 321 // In this first sparse tensor output implementation, this is enforced by 322 // rejecting any reduction loops (since the sparse parallel loops give a 323 // lexicographically sorted and injective view into that tensor). 324 // TODO: generalize to include reductions 325 for (auto attr : op.iterator_types()) 326 if (isReductionIterator(attr)) 327 return false; 328 *sparseOut = lhs; 329 return true; 330 } 331 return false; 332 } 333 334 //===----------------------------------------------------------------------===// 335 // Sparse compiler synthesis methods (reductions). 336 //===----------------------------------------------------------------------===// 337 338 /// Maps reduction kind to name encoding. 
339 static StringRef getReductionName(Reduction kind) { 340 switch (kind) { 341 case kNoReduc: 342 break; 343 case kSum: 344 return "add"; 345 case kProduct: 346 return "mul"; 347 case kAnd: 348 return "and"; 349 case kOr: 350 return "or"; 351 case kXor: 352 return "xor"; 353 } 354 llvm_unreachable("unknown reduction kind"); 355 } 356 357 /// Maps operation to reduction. 358 static Reduction getReduction(Kind kind) { 359 switch (kind) { 360 case Kind::kAddF: 361 case Kind::kAddI: 362 case Kind::kSubF: 363 case Kind::kSubI: 364 return kSum; 365 case Kind::kMulF: 366 case Kind::kMulI: 367 return kProduct; 368 case Kind::kAndI: 369 return kAnd; 370 case Kind::kOrI: 371 return kOr; 372 case Kind::kXorI: 373 return kXor; 374 default: 375 llvm_unreachable("unexpected reduction operator"); 376 } 377 } 378 379 /// Generates an initial value for a vector reduction, following the scheme 380 /// given in Chapter 5 of "The Software Vectorization Handbook", where the 381 /// initial scalar value is correctly embedded in the vector reduction value, 382 /// and a straightforward horizontal reduction will complete the operation. 383 static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter, 384 Location loc, VectorType vtp) { 385 Value r = codegen.redVal; 386 switch (codegen.redKind) { 387 case kNoReduc: 388 break; 389 case kSum: 390 case kXor: { 391 // Initialize reduction vector to: | 0 | .. | 0 | r | 392 Attribute zero = rewriter.getZeroAttr(vtp); 393 Value vec = rewriter.create<arith::ConstantOp>(loc, vtp, zero); 394 return rewriter.create<vector::InsertElementOp>( 395 loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0)); 396 } 397 case kProduct: { 398 // Initialize reduction vector to: | 1 | .. | 1 | r | 399 Type etp = vtp.getElementType(); 400 Attribute one; 401 if (etp.isa<FloatType>()) 402 one = rewriter.getFloatAttr(etp, 1.0); 403 else 404 one = rewriter.getIntegerAttr(etp, 1); 405 Value vec = rewriter.create<arith::ConstantOp>( 406 loc, vtp, DenseElementsAttr::get(vtp, one)); 407 return rewriter.create<vector::InsertElementOp>( 408 loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0)); 409 } 410 case kAnd: 411 case kOr: 412 // Initialize reduction vector to: | r | .. | r | r | 413 return rewriter.create<vector::BroadcastOp>(loc, vtp, r); 414 } 415 llvm_unreachable("unknown reduction kind"); 416 } 417 418 /// Generates final value for a vector reduction. 419 static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter, 420 Location loc, VectorType vtp) { 421 StringRef name = getReductionName(codegen.redKind); 422 StringAttr kind = rewriter.getStringAttr(name); 423 return rewriter.create<vector::ReductionOp>(loc, vtp.getElementType(), kind, 424 codegen.redVal, ValueRange{}); 425 } 426 427 /// Updates scalarized reduction value. 428 static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) { 429 assert(codegen.redKind != kNoReduc); 430 codegen.redVal = merger.exp(codegen.redExp).val = reduc; 431 } 432 433 //===----------------------------------------------------------------------===// 434 // Sparse compiler synthesis methods (statements and expressions). 435 //===----------------------------------------------------------------------===// 436 437 /// Maps sparse integer option to actual integral storage type. 438 static Type genIntType(PatternRewriter &rewriter, unsigned width) { 439 if (width == 0) 440 return rewriter.getIndexType(); 441 return rewriter.getIntegerType(width); 442 } 443 444 /// Generates buffer for the output tensor. 
Note that all sparse kernels 445 /// assume that when all elements are written to (viz. x(i) = y(i) * z(i)), 446 /// the output buffer is already initialized to all zeroes and only nonzeroes 447 /// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)), 448 /// only nonzeroes values are used for the updates and no assumption on the 449 /// original contents of the output buffer is necessary.. 450 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter, 451 linalg::GenericOp op, MemRefType denseTp, 452 ArrayRef<Value> args) { 453 Location loc = op.getLoc(); 454 Value tensor = op.getOutputOperand(0)->get(); 455 // The output tensor simply could materialize from the buffer that will 456 // be generated for the tensor present in the outs() clause. This has 457 // the major advantage that the sparse kernel only updates the nonzero 458 // positions for the output tensor. 459 if (isInPlace(tensor)) 460 return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor); 461 // By default, a new buffer is allocated which is initialized to the 462 // tensor defined in the outs() clause. This is always correct but 463 // introduces a dense initialization component that may negatively 464 // impact the running complexity of the sparse kernel. If the tensor 465 // materializes into the computation, we need to preserve the zero 466 // initialization assumption of all sparse output buffers. 467 if (isMaterializing(tensor)) { 468 Type tp = denseTp.getElementType(); 469 Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args); 470 Value zero = 471 rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp)); 472 rewriter.create<linalg::FillOp>(loc, zero, alloc); 473 return alloc; 474 } 475 Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor); 476 Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args); 477 rewriter.create<memref::CopyOp>(loc, init, alloc); 478 return alloc; 479 } 480 481 /// Local bufferization of all dense and sparse data structures. 482 /// This code enables testing the first prototype sparse compiler. 483 // TODO: replace this with a proliferated bufferization strategy 484 static void genBuffers(Merger &merger, CodeGen &codegen, 485 PatternRewriter &rewriter, linalg::GenericOp op) { 486 Location loc = op.getLoc(); 487 assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); 488 // For every tensor, find lower and upper bound on dimensions, set the 489 // same bounds on loop indices, and obtain dense or sparse buffer(s). 490 SmallVector<Value, 4> args; 491 for (OpOperand *t : op.getInputAndOutputOperands()) { 492 unsigned tensor = t->getOperandNumber(); 493 auto shape = op.getShape(t); 494 auto map = op.getTiedIndexingMap(t); 495 auto enc = getSparseTensorEncoding(t->get().getType()); 496 // Scan all dimensions of current tensor. 497 args.clear(); 498 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 499 AffineExpr a = map.getResult(perm(enc, d)); 500 if (a.getKind() != AffineExprKind::DimId) 501 continue; // compound 502 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 503 // Handle sparse storage schemes. 
504 if (merger.isDim(tensor, idx, Dim::kSparse)) { 505 auto dynShape = {ShapedType::kDynamicSize}; 506 auto ptrTp = MemRefType::get( 507 dynShape, genIntType(rewriter, enc.getPointerBitWidth())); 508 auto indTp = MemRefType::get( 509 dynShape, genIntType(rewriter, enc.getIndexBitWidth())); 510 Value dim = rewriter.create<arith::ConstantIndexOp>(loc, d); 511 // Generate sparse primitives to obtains pointer and indices. 512 codegen.pointers[tensor][idx] = 513 rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim); 514 codegen.indices[tensor][idx] = 515 rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim); 516 } 517 // Find upper bound in current dimension. 518 unsigned p = perm(enc, d); 519 Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p); 520 if (shape[p] == MemRefType::kDynamicSize) 521 args.push_back(up); 522 assert(codegen.highs[tensor][idx] == nullptr); 523 codegen.sizes[idx] = codegen.highs[tensor][idx] = up; 524 } 525 // Perform the required bufferization. Dense inputs materialize 526 // from the input tensors. Dense outputs need special handling. 527 // Sparse inputs use sparse primitives to obtain the values. 528 // We also accept in-place all-dense annotated "sparse" outputs. 529 Type elementType = getElementTypeOrSelf(t->get().getType()); 530 if (!enc) { 531 // Non-annotated dense tensors. 532 auto denseTp = MemRefType::get(shape, elementType); 533 if (tensor < op.getNumInputs()) 534 codegen.buffers[tensor] = 535 rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get()); 536 else 537 codegen.buffers[tensor] = 538 genOutputBuffer(codegen, rewriter, op, denseTp, args); 539 } else if (t == codegen.sparseOut) { 540 // True sparse output needs a lexIdx array. 541 Value rank = rewriter.create<arith::ConstantIndexOp>(loc, op.getRank(t)); 542 auto dynShape = {ShapedType::kDynamicSize}; 543 auto memTp = MemRefType::get(dynShape, rewriter.getIndexType()); 544 codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank); 545 } else { 546 // Annotated sparse tensors. 547 auto dynShape = {ShapedType::kDynamicSize}; 548 auto sparseTp = MemRefType::get(dynShape, elementType); 549 codegen.buffers[tensor] = 550 rewriter.create<ToValuesOp>(loc, sparseTp, t->get()); 551 } 552 } 553 } 554 555 /// Constructs vector type. 556 static VectorType vectorType(CodeGen &codegen, Type etp) { 557 return VectorType::get(codegen.curVecLength, etp); 558 } 559 560 /// Constructs vector type from pointer. 561 static VectorType vectorType(CodeGen &codegen, Value ptr) { 562 return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType()); 563 } 564 565 /// Constructs vector iteration mask. 566 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter, 567 Value iv, Value lo, Value hi, Value step) { 568 Location loc = iv.getLoc(); 569 VectorType mtp = vectorType(codegen, genIntType(rewriter, 1)); 570 // Special case if the vector length evenly divides the trip count (for 571 // example, "for i = 0, 128, 16"). A constant all-true mask is generated 572 // so that all subsequent masked memory operations are immediately folded 573 // into unconditional memory operations. 
574 IntegerAttr loInt, hiInt, stepInt; 575 if (matchPattern(lo, m_Constant(&loInt)) && 576 matchPattern(hi, m_Constant(&hiInt)) && 577 matchPattern(step, m_Constant(&stepInt))) { 578 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) 579 return rewriter.create<vector::BroadcastOp>( 580 loc, mtp, rewriter.create<arith::ConstantIntOp>(loc, 1, 1)); 581 } 582 // Otherwise, generate a vector mask that avoids overrunning the upperbound 583 // during vector execution. Here we rely on subsequent loop optimizations to 584 // avoid executing the mask in all iterations, for example, by splitting the 585 // loop into an unconditional vector loop and a scalar cleanup loop. 586 auto minMap = AffineMap::get( 587 /*dimCount=*/2, /*symbolCount=*/1, 588 {rewriter.getAffineSymbolExpr(0), 589 rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)}, 590 rewriter.getContext()); 591 Value end = 592 rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step}); 593 return rewriter.create<vector::CreateMaskOp>(loc, mtp, end); 594 } 595 596 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. 597 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter, 598 Value ptr, ArrayRef<Value> args) { 599 Location loc = ptr.getLoc(); 600 VectorType vtp = vectorType(codegen, ptr); 601 Value pass = 602 rewriter.create<arith::ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp)); 603 if (args.back().getType().isa<VectorType>()) { 604 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 605 Value indexVec = args.back(); 606 scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0); 607 return rewriter.create<vector::GatherOp>( 608 loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); 609 } 610 return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args, 611 codegen.curVecMask, pass); 612 } 613 614 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. 615 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter, 616 Value rhs, Value ptr, ArrayRef<Value> args) { 617 Location loc = ptr.getLoc(); 618 if (args.back().getType().isa<VectorType>()) { 619 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 620 Value indexVec = args.back(); 621 scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0); 622 rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, 623 codegen.curVecMask, rhs); 624 return; 625 } 626 rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask, 627 rhs); 628 } 629 630 /// Generates a vectorized invariant. Here we rely on subsequent loop 631 /// optimizations to hoist the invariant broadcast out of the vector loop. 632 static Value genVectorInvariantValue(CodeGen &codegen, 633 PatternRewriter &rewriter, Value val) { 634 VectorType vtp = vectorType(codegen, val.getType()); 635 return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 636 } 637 638 /// Generates an affine expression. 
639 // 640 // TODO: generalize for sparse tensor subscripts 641 // 642 static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter, 643 AffineExpr a, Location loc) { 644 switch (a.getKind()) { 645 case AffineExprKind::DimId: { 646 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 647 return codegen.loops[idx]; // universal dense index 648 } 649 case AffineExprKind::Add: { 650 auto binOp = a.cast<AffineBinaryOpExpr>(); 651 return rewriter.create<arith::AddIOp>( 652 loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 653 genAffine(codegen, rewriter, binOp.getRHS(), loc)); 654 } 655 case AffineExprKind::Mul: { 656 auto binOp = a.cast<AffineBinaryOpExpr>(); 657 return rewriter.create<arith::MulIOp>( 658 loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 659 genAffine(codegen, rewriter, binOp.getRHS(), loc)); 660 } 661 case AffineExprKind::Constant: { 662 int64_t c = a.cast<AffineConstantExpr>().getValue(); 663 return rewriter.create<arith::ConstantIndexOp>(loc, c); 664 } 665 default: 666 llvm_unreachable("unexpected affine subscript"); 667 } 668 } 669 670 /// Generates subscript for load/store on a dense or sparse tensor. 671 static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter, 672 linalg::GenericOp op, OpOperand *t, 673 SmallVector<Value, 4> &args) { 674 unsigned tensor = t->getOperandNumber(); 675 auto map = op.getTiedIndexingMap(t); 676 auto enc = getSparseTensorEncoding(t->get().getType()); 677 unsigned rank = map.getNumResults(); 678 if (enc) { 679 // Note that currently, all sparse subscripts are simple. 680 // TODO: accept affine too? 681 AffineExpr a = map.getResult(perm(enc, rank - 1)); 682 assert(a.getKind() == AffineExprKind::DimId); 683 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 684 assert(codegen.pidxs[tensor][idx] != nullptr); 685 args.push_back(codegen.pidxs[tensor][idx]); // position index 686 } else { 687 for (unsigned d = 0; d < rank; d++) { 688 AffineExpr a = map.getResult(perm(enc, d)); 689 args.push_back(genAffine(codegen, rewriter, a, op.getLoc())); 690 } 691 } 692 return codegen.buffers[tensor]; 693 } 694 695 /// Generates a load on a dense or sparse tensor. 696 static Value genTensorLoad(Merger &merger, CodeGen &codegen, 697 PatternRewriter &rewriter, linalg::GenericOp op, 698 unsigned exp) { 699 // Test if the load was hoisted to a higher loop nest. 700 Value val = merger.exp(exp).val; 701 if (val) { 702 if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>()) 703 return genVectorInvariantValue(codegen, rewriter, val); 704 return val; 705 } 706 // Actual load. 707 SmallVector<Value, 4> args; 708 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 709 Value ptr = genSubscript(codegen, rewriter, op, t, args); 710 if (codegen.curVecLength > 1) 711 return genVectorLoad(codegen, rewriter, ptr, args); 712 return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args); 713 } 714 715 /// Generates a store on a dense or sparse tensor. 716 static void genTensorStore(Merger &merger, CodeGen &codegen, 717 PatternRewriter &rewriter, linalg::GenericOp op, 718 Value rhs) { 719 Location loc = op.getLoc(); 720 // Test if this is a scalarized reduction. 721 if (codegen.redVal) { 722 if (codegen.curVecLength > 1) 723 rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs, 724 codegen.redVal); 725 updateReduc(merger, codegen, rhs); 726 return; 727 } 728 // Insertion. 
729 OpOperand *t = op.getOutputOperand(0); 730 if (t == codegen.sparseOut) { 731 rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs); 732 return; 733 } 734 // Actual store. 735 SmallVector<Value, 4> args; 736 Value ptr = genSubscript(codegen, rewriter, op, t, args); 737 if (codegen.curVecLength > 1) 738 genVectorStore(codegen, rewriter, rhs, ptr, args); 739 else 740 rewriter.create<memref::StoreOp>(loc, rhs, ptr, args); 741 } 742 743 /// Generates a pointer/index load from the sparse storage scheme. Narrower 744 /// data types need to be zero extended before casting the value into the 745 /// index type used for looping and indexing. 746 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc, 747 Value ptr, Value s) { 748 // See https://llvm.org/docs/GetElementPtr.html for some background on 749 // the complications described below. 750 if (codegen.curVecLength > 1) { 751 // Since the index vector is used in a subsequent gather/scatter operations, 752 // which effectively defines an unsigned pointer + signed index, we must 753 // zero extend the vector to an index width. For 8-bit and 16-bit values, 754 // an 32-bit index width suffices. For 32-bit values, zero extending the 755 // elements into 64-bit loses some performance since the 32-bit indexed 756 // gather/scatter is more efficient than the 64-bit index variant (if the 757 // negative 32-bit index space is unused, the enableSIMDIndex32 flag can 758 // preserve this performance). For 64-bit values, there is no good way 759 // to state that the indices are unsigned, with creates the potential of 760 // incorrect address calculations in the unlikely case we need such 761 // extremely large offsets. 762 Type etp = ptr.getType().cast<MemRefType>().getElementType(); 763 Value vload = genVectorLoad(codegen, rewriter, ptr, {s}); 764 if (!etp.isa<IndexType>()) { 765 if (etp.getIntOrFloatBitWidth() < 32) 766 vload = rewriter.create<arith::ExtUIOp>( 767 loc, vload, vectorType(codegen, genIntType(rewriter, 32))); 768 else if (etp.getIntOrFloatBitWidth() < 64 && 769 !codegen.options.enableSIMDIndex32) 770 vload = rewriter.create<arith::ExtUIOp>( 771 loc, vload, vectorType(codegen, genIntType(rewriter, 64))); 772 } 773 return vload; 774 } 775 // For the scalar case, we simply zero extend narrower indices into 64-bit 776 // values before casting to index without a performance penalty. Here too, 777 // however, indices that already are 64-bit, in theory, cannot express the 778 // full range as explained above. 779 Value load = rewriter.create<memref::LoadOp>(loc, ptr, s); 780 if (!load.getType().isa<IndexType>()) { 781 if (load.getType().getIntOrFloatBitWidth() < 64) 782 load = 783 rewriter.create<arith::ExtUIOp>(loc, load, genIntType(rewriter, 64)); 784 load = 785 rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType()); 786 } 787 return load; 788 } 789 790 /// Generates an invariant value. 791 static Value genInvariantValue(Merger &merger, CodeGen &codegen, 792 PatternRewriter &rewriter, unsigned exp) { 793 Value val = merger.exp(exp).val; 794 if (codegen.curVecLength > 1) 795 return genVectorInvariantValue(codegen, rewriter, val); 796 return val; 797 } 798 799 /// Generates an address computation "sz * p + i". 
800 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter, 801 Location loc, Value size, Value p, Value i) { 802 Value mul = rewriter.create<arith::MulIOp>(loc, size, p); 803 if (auto vtp = i.getType().dyn_cast<VectorType>()) { 804 Value inv = 805 rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType()); 806 mul = genVectorInvariantValue(codegen, rewriter, inv); 807 } 808 return rewriter.create<arith::AddIOp>(loc, mul, i); 809 } 810 811 /// Recursively generates tensor expression. 812 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 813 linalg::GenericOp op, unsigned exp) { 814 Location loc = op.getLoc(); 815 if (exp == -1u) 816 return Value(); 817 if (merger.exp(exp).kind == Kind::kTensor) 818 return genTensorLoad(merger, codegen, rewriter, op, exp); 819 if (merger.exp(exp).kind == Kind::kInvariant) 820 return genInvariantValue(merger, codegen, rewriter, exp); 821 Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0); 822 Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1); 823 return merger.buildExp(rewriter, loc, exp, v0, v1); 824 } 825 826 /// Determines if affine expression is invariant. 827 static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a, 828 unsigned ldx, bool &atLevel) { 829 switch (a.getKind()) { 830 case AffineExprKind::DimId: { 831 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 832 if (idx == ldx) 833 atLevel = true; 834 return codegen.loops[idx] != nullptr; // no longer in play? 835 } 836 case AffineExprKind::Add: 837 case AffineExprKind::Mul: { 838 auto binOp = a.cast<AffineBinaryOpExpr>(); 839 return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) && 840 isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel); 841 } 842 default: 843 return true; 844 } 845 } 846 847 /// Hoists loop invariant tensor loads for which indices have been exhausted. 848 static void genInvariants(Merger &merger, CodeGen &codegen, 849 PatternRewriter &rewriter, linalg::GenericOp op, 850 unsigned exp, unsigned ldx, bool atStart, 851 Kind last = Kind::kTensor) { 852 if (exp == -1u) 853 return; 854 if (merger.exp(exp).kind == Kind::kTensor) { 855 // Inspect tensor indices. 856 bool atLevel = ldx == -1u; 857 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 858 auto map = op.getTiedIndexingMap(t); 859 auto enc = getSparseTensorEncoding(t->get().getType()); 860 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 861 AffineExpr a = map.getResult(perm(enc, d)); 862 if (!isInvariantAffine(codegen, a, ldx, atLevel)) 863 return; // still in play 864 } 865 // All exhausted at this level (atLevel denotes exactly at this level). 866 if (!atLevel) 867 return; 868 OpOperand *lhs = op.getOutputOperand(0); 869 if (lhs == t) { 870 // Start or end a scalarized reduction 871 if (atStart) { 872 Value load = genTensorLoad(merger, codegen, rewriter, op, exp); 873 codegen.redKind = getReduction(last); 874 codegen.redExp = exp; 875 updateReduc(merger, codegen, load); 876 } else { 877 Value redVal = codegen.redVal; 878 updateReduc(merger, codegen, Value()); 879 codegen.redExp = -1u; 880 codegen.redKind = kNoReduc; 881 genTensorStore(merger, codegen, rewriter, op, redVal); 882 } 883 } else { 884 // Start or end loop invariant hoisting of a tensor load. 885 merger.exp(exp).val = 886 atStart ? 
genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); 887 } 888 } else if (merger.exp(exp).kind != Kind::kInvariant) { 889 // Traverse into the binary operations. Note that we only hoist 890 // tensor loads, since subsequent MLIR/LLVM passes know how to 891 // deal with all other kinds of derived loop invariants. 892 Kind last = merger.exp(exp).kind; 893 unsigned e0 = merger.exp(exp).children.e0; 894 unsigned e1 = merger.exp(exp).children.e1; 895 genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last); 896 genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last); 897 } 898 } 899 900 /// Generates initialization code for the subsequent loop sequence at 901 /// current index level. Returns true if the loop sequence needs to 902 /// maintain the universal index. 903 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 904 linalg::GenericOp op, std::vector<unsigned> &topSort, 905 unsigned at, llvm::BitVector &inits) { 906 bool needsUniv = false; 907 Location loc = op.getLoc(); 908 unsigned idx = topSort[at]; 909 910 // Initialize sparse positions. 911 for (unsigned b = 0, be = inits.size(); b < be; b++) { 912 if (inits[b]) { 913 unsigned tensor = merger.tensor(b); 914 assert(idx == merger.index(b)); 915 if (merger.isDim(b, Dim::kSparse)) { 916 // Initialize sparse index. 917 unsigned pat = at; 918 for (; pat != 0; pat--) { 919 if (codegen.pidxs[tensor][topSort[pat - 1]]) 920 break; 921 } 922 Value ptr = codegen.pointers[tensor][idx]; 923 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 924 Value p0 = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0) 925 : codegen.pidxs[tensor][topSort[pat - 1]]; 926 codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); 927 Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one); 928 codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1); 929 } else { 930 // Dense index still in play. 931 needsUniv = true; 932 } 933 } 934 } 935 936 // Initialize the universal dense index. 937 codegen.loops[idx] = rewriter.create<arith::ConstantIndexOp>(loc, 0); 938 return needsUniv; 939 } 940 941 /// Returns vectorization strategy. Any implicit inner loop in the Linalg 942 /// operation is a candidate. Whether it is actually converted to SIMD code 943 /// depends on the requested strategy. 944 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) { 945 switch (codegen.options.vectorizationStrategy) { 946 case SparseVectorizationStrategy::kNone: 947 return false; 948 case SparseVectorizationStrategy::kDenseInnerLoop: 949 return isInner && !isSparse; 950 case SparseVectorizationStrategy::kAnyStorageInnerLoop: 951 return isInner; 952 } 953 llvm_unreachable("unexpected vectorization strategy"); 954 } 955 956 /// Returns parallelization strategy. Any implicit loop in the Linalg operation 957 /// that is marked "parallel" is a candidate. Whether it is actually converted 958 /// to a parallel operation depends on the requested strategy. 
959 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, 960 bool isSparse, bool isVector) { 961 switch (codegen.options.parallelizationStrategy) { 962 case SparseParallelizationStrategy::kNone: 963 return false; 964 case SparseParallelizationStrategy::kDenseOuterLoop: 965 return isOuter && !isSparse && !isReduction && !isVector; 966 case SparseParallelizationStrategy::kAnyStorageOuterLoop: 967 return isOuter && !isReduction && !isVector; 968 case SparseParallelizationStrategy::kDenseAnyLoop: 969 return !isSparse && !isReduction && !isVector; 970 case SparseParallelizationStrategy::kAnyStorageAnyLoop: 971 return !isReduction && !isVector; 972 } 973 llvm_unreachable("unexpected parallelization strategy"); 974 } 975 976 /// Checks unit stride for dense tensors. The iteration graph may have ignored 977 /// dense access patterns in order to avoid cycles (sparse access patterns are 978 /// always placed innermost), but that means dense access has become strided. 979 /// This prevents effective vectorization. 980 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 981 unsigned idx) { 982 for (OpOperand *t : op.getInputAndOutputOperands()) { 983 if (!getSparseTensorEncoding(t->get().getType())) { 984 auto map = op.getTiedIndexingMap(t); 985 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 986 AffineExpr a = map.getResult(d); 987 // Report non-unit stride if innermost index appears at an outer 988 // dimension (true non-unit stride) or if the innermost index appears 989 // in a compound subscript in the innermost dimension. Even if the 990 // latter is unit stride, it does not play well with scatter/gather. 991 // TODO: accept unit stride affine innermost like a[i,j+k+1]? 992 if (a.isFunctionOfDim(idx) && 993 ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) 994 return false; 995 } 996 } 997 } 998 return true; 999 } 1000 1001 /// Generates a for-loop on a single index. 1002 static Operation *genFor(Merger &merger, CodeGen &codegen, 1003 PatternRewriter &rewriter, linalg::GenericOp op, 1004 bool isOuter, bool isInner, unsigned idx, 1005 llvm::BitVector &indices) { 1006 unsigned fb = indices.find_first(); 1007 unsigned tensor = merger.tensor(fb); 1008 assert(idx == merger.index(fb)); 1009 auto iteratorTypes = op.iterator_types().getValue(); 1010 bool isReduction = isReductionIterator(iteratorTypes[idx]); 1011 bool isSparse = merger.isDim(fb, Dim::kSparse); 1012 bool isVector = !codegen.sparseOut && 1013 isVectorFor(codegen, isInner, isSparse) && 1014 denseUnitStrides(merger, op, idx); 1015 bool isParallel = 1016 !codegen.sparseOut && 1017 isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 1018 1019 // Prepare vector length. 1020 if (isVector) 1021 codegen.curVecLength = codegen.options.vectorLength; 1022 1023 // Loop bounds and increment. 1024 Location loc = op.getLoc(); 1025 Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 1026 Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 1027 Value step = 1028 rewriter.create<arith::ConstantIndexOp>(loc, codegen.curVecLength); 1029 1030 // Emit a parallel loop. 
1031 if (isParallel) { 1032 assert(!isVector); 1033 scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step); 1034 if (isSparse) 1035 codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 1036 else 1037 codegen.loops[idx] = parOp.getInductionVars()[0]; 1038 rewriter.setInsertionPointToStart(parOp.getBody()); 1039 return parOp; 1040 } 1041 1042 // Emit a sequential or vector loop. 1043 SmallVector<Value, 4> operands; 1044 if (codegen.redVal) { 1045 // In a vector loop, bring reduction into SIMD form, if not already. 1046 if (isVector && !codegen.redVal.getType().isa<VectorType>()) { 1047 VectorType vtp = vectorType(codegen, codegen.redVal.getType()); 1048 Value vred = genVectorReducInit(codegen, rewriter, loc, vtp); 1049 updateReduc(merger, codegen, vred); 1050 } 1051 operands.push_back(codegen.redVal); 1052 } 1053 scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands); 1054 if (codegen.redVal) 1055 updateReduc(merger, codegen, forOp.getRegionIterArgs().front()); 1056 // Assign induction variable to sparse or dense index. 1057 Value iv = forOp.getInductionVar(); 1058 if (isSparse) 1059 codegen.pidxs[tensor][idx] = iv; 1060 else 1061 codegen.loops[idx] = iv; 1062 rewriter.setInsertionPointToStart(forOp.getBody()); 1063 // Share vector iteration mask between all subsequent loads/stores. 1064 if (isVector) 1065 codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step); 1066 return forOp; 1067 } 1068 1069 /// Emit a while-loop for co-iteration over multiple indices. 1070 static Operation *genWhile(Merger &merger, CodeGen &codegen, 1071 PatternRewriter &rewriter, linalg::GenericOp op, 1072 unsigned idx, bool needsUniv, 1073 llvm::BitVector &indices) { 1074 SmallVector<Type, 4> types; 1075 SmallVector<Value, 4> operands; 1076 // Construct the while-loop with a parameter for each index. 1077 Type indexType = rewriter.getIndexType(); 1078 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1079 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1080 unsigned tensor = merger.tensor(b); 1081 assert(idx == merger.index(b)); 1082 types.push_back(indexType); 1083 operands.push_back(codegen.pidxs[tensor][idx]); 1084 } 1085 } 1086 if (codegen.redVal) { 1087 types.push_back(codegen.redVal.getType()); 1088 operands.push_back(codegen.redVal); 1089 } 1090 if (needsUniv) { 1091 types.push_back(indexType); 1092 operands.push_back(codegen.loops[idx]); 1093 } 1094 assert(types.size() == operands.size()); 1095 Location loc = op.getLoc(); 1096 scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); 1097 Block *before = rewriter.createBlock(&whileOp.before(), {}, types); 1098 Block *after = rewriter.createBlock(&whileOp.after(), {}, types); 1099 1100 // Build the "before" region, which effectively consists 1101 // of a conjunction of "i < upper" tests on all induction. 1102 rewriter.setInsertionPointToStart(&whileOp.before().front()); 1103 Value cond; 1104 unsigned o = 0; 1105 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1106 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1107 unsigned tensor = merger.tensor(b); 1108 assert(idx == merger.index(b)); 1109 Value op1 = before->getArgument(o); 1110 Value op2 = codegen.highs[tensor][idx]; 1111 Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 1112 op1, op2); 1113 cond = cond ? 
rewriter.create<arith::AndIOp>(loc, cond, opc) : opc; 1114 codegen.pidxs[tensor][idx] = after->getArgument(o++); 1115 } 1116 } 1117 if (codegen.redVal) 1118 updateReduc(merger, codegen, after->getArgument(o++)); 1119 if (needsUniv) 1120 codegen.loops[idx] = after->getArgument(o++); 1121 assert(o == operands.size()); 1122 rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); 1123 rewriter.setInsertionPointToStart(&whileOp.after().front()); 1124 return whileOp; 1125 } 1126 1127 /// Generates a for-loop or a while-loop, depending on whether it implements 1128 /// singleton iteration or co-iteration over the given conjunction. 1129 static Operation *genLoop(Merger &merger, CodeGen &codegen, 1130 PatternRewriter &rewriter, linalg::GenericOp op, 1131 std::vector<unsigned> &topSort, unsigned at, 1132 bool needsUniv, llvm::BitVector &indices) { 1133 unsigned idx = topSort[at]; 1134 if (indices.count() == 1) { 1135 bool isOuter = at == 0; 1136 bool isInner = at == topSort.size() - 1; 1137 return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, 1138 indices); 1139 } 1140 return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); 1141 } 1142 1143 /// Generates the local variables for this loop, consisting of the sparse 1144 /// indices, restored universal dense index, and dense positions. 1145 static void genLocals(Merger &merger, CodeGen &codegen, 1146 PatternRewriter &rewriter, linalg::GenericOp op, 1147 std::vector<unsigned> &topSort, unsigned at, 1148 bool needsUniv, llvm::BitVector &locals) { 1149 Location loc = op.getLoc(); 1150 unsigned idx = topSort[at]; 1151 1152 // Initialize sparse indices. 1153 Value min; 1154 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1155 if (locals[b] && merger.isDim(b, Dim::kSparse)) { 1156 unsigned tensor = merger.tensor(b); 1157 assert(idx == merger.index(b)); 1158 Value ptr = codegen.indices[tensor][idx]; 1159 Value s = codegen.pidxs[tensor][idx]; 1160 Value load = genLoad(codegen, rewriter, loc, ptr, s); 1161 codegen.idxs[tensor][idx] = load; 1162 if (!needsUniv) { 1163 if (min) { 1164 Value cmp = rewriter.create<arith::CmpIOp>( 1165 loc, arith::CmpIPredicate::ult, load, min); 1166 min = rewriter.create<SelectOp>(loc, cmp, load, min); 1167 } else { 1168 min = load; 1169 } 1170 } 1171 } 1172 } 1173 1174 // Merge dense universal index over minimum. 1175 if (min) { 1176 assert(!needsUniv); 1177 codegen.loops[idx] = min; 1178 } 1179 1180 // Initialize dense positions. Note that we generate dense indices of the 1181 // output tensor unconditionally, since they may not appear in the lattice, 1182 // but may be needed for linearized codegen. 1183 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1184 if ((locals[b] || merger.isOutTensor(b, idx)) && 1185 merger.isDim(b, Dim::kDense)) { 1186 unsigned tensor = merger.tensor(b); 1187 assert(idx == merger.index(b)); 1188 unsigned pat = at; 1189 for (; pat != 0; pat--) 1190 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1191 break; 1192 Value p = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0) 1193 : codegen.pidxs[tensor][topSort[pat - 1]]; 1194 codegen.pidxs[tensor][idx] = genAddress( 1195 codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); 1196 } 1197 } 1198 1199 // Move the insertion indices in lexicographic index order. 
1200 if (codegen.sparseOut) { 1201 Value pos = rewriter.create<arith::ConstantIndexOp>(loc, at); 1202 rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx, 1203 pos); 1204 } 1205 } 1206 1207 /// Generates the induction structure for a while-loop. 1208 static void genWhileInduction(Merger &merger, CodeGen &codegen, 1209 PatternRewriter &rewriter, linalg::GenericOp op, 1210 unsigned idx, bool needsUniv, 1211 llvm::BitVector &induction, 1212 scf::WhileOp whileOp) { 1213 Location loc = op.getLoc(); 1214 // Finalize each else branch of all if statements. 1215 if (codegen.redVal) { 1216 while (auto ifOp = dyn_cast_or_null<scf::IfOp>( 1217 rewriter.getInsertionBlock()->getParentOp())) { 1218 rewriter.create<scf::YieldOp>(loc, codegen.redVal); 1219 updateReduc(merger, codegen, ifOp.getResult(0)); 1220 rewriter.setInsertionPointAfter(ifOp); 1221 } 1222 } 1223 rewriter.setInsertionPointToEnd(&whileOp.after().front()); 1224 // Finalize the induction. Note that the induction could be performed 1225 // in the individual if-branches to avoid re-evaluating the conditions. 1226 // However, that would result in a rather elaborate forest of yield 1227 // instructions during code generation. Moreover, performing the induction 1228 // after the if-statements more closely resembles code generated by TACO. 1229 unsigned o = 0; 1230 SmallVector<Value, 4> operands; 1231 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 1232 for (unsigned b = 0, be = induction.size(); b < be; b++) { 1233 if (induction[b] && merger.isDim(b, Dim::kSparse)) { 1234 unsigned tensor = merger.tensor(b); 1235 assert(idx == merger.index(b)); 1236 Value op1 = codegen.idxs[tensor][idx]; 1237 Value op2 = codegen.loops[idx]; 1238 Value op3 = codegen.pidxs[tensor][idx]; 1239 Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 1240 op1, op2); 1241 Value add = rewriter.create<arith::AddIOp>(loc, op3, one); 1242 operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3)); 1243 codegen.pidxs[tensor][idx] = whileOp->getResult(o++); 1244 } 1245 } 1246 if (codegen.redVal) { 1247 operands.push_back(codegen.redVal); 1248 updateReduc(merger, codegen, whileOp->getResult(o++)); 1249 } 1250 if (needsUniv) { 1251 operands.push_back( 1252 rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one)); 1253 codegen.loops[idx] = whileOp->getResult(o++); 1254 } 1255 assert(o == operands.size()); 1256 rewriter.create<scf::YieldOp>(loc, operands); 1257 rewriter.setInsertionPointAfter(whileOp); 1258 } 1259 1260 /// Generates the induction structure for a for-loop. 1261 static void genForInduction(Merger &merger, CodeGen &codegen, 1262 PatternRewriter &rewriter, linalg::GenericOp op, 1263 Operation *loop) { 1264 Location loc = op.getLoc(); 1265 unsigned o = 0; 1266 SmallVector<Value, 4> operands; 1267 if (codegen.redVal) { 1268 operands.push_back(codegen.redVal); 1269 updateReduc(merger, codegen, loop->getResult(o++)); 1270 } 1271 assert(o == operands.size()); 1272 if (o > 0) 1273 rewriter.create<scf::YieldOp>(loc, operands); 1274 rewriter.setInsertionPointAfter(loop); 1275 } 1276 1277 /// Generates a single if-statement within a while-loop. 
1278 static scf::IfOp genIf(Merger &merger, CodeGen &codegen, 1279 PatternRewriter &rewriter, linalg::GenericOp op, 1280 unsigned idx, llvm::BitVector &conditions) { 1281 Location loc = op.getLoc(); 1282 SmallVector<Type, 4> types; 1283 Value cond; 1284 for (unsigned b = 0, be = conditions.size(); b < be; b++) { 1285 if (conditions[b]) { 1286 unsigned tensor = merger.tensor(b); 1287 assert(idx == merger.index(b)); 1288 Value clause; 1289 if (merger.isDim(b, Dim::kSparse)) { 1290 Value op1 = codegen.idxs[tensor][idx]; 1291 Value op2 = codegen.loops[idx]; 1292 clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 1293 op1, op2); 1294 } else { 1295 clause = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true 1296 } 1297 cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause; 1298 } 1299 } 1300 if (codegen.redVal) 1301 types.push_back(codegen.redVal.getType()); 1302 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true); 1303 rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); 1304 return ifOp; 1305 } 1306 1307 /// Generates end of true branch of if-statement within a while-loop. 1308 static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 1309 linalg::GenericOp op, scf::IfOp ifOp, Value ifInput) { 1310 if (codegen.redVal) { 1311 rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal); 1312 updateReduc(merger, codegen, ifInput); 1313 } 1314 rewriter.setInsertionPointToStart(&ifOp.elseRegion().front()); 1315 } 1316 1317 //===----------------------------------------------------------------------===// 1318 // Sparse compiler synthesis methods (loop sequence). 1319 //===----------------------------------------------------------------------===// 1320 1321 /// Starts a loop sequence at given level. Returns true if 1322 /// the universal loop index must be maintained at this level. 1323 static bool startLoopSeq(Merger &merger, CodeGen &codegen, 1324 PatternRewriter &rewriter, linalg::GenericOp op, 1325 std::vector<unsigned> &topSort, unsigned exp, 1326 unsigned at, unsigned idx, unsigned ldx, 1327 unsigned lts) { 1328 assert(codegen.curVecLength == 1); 1329 assert(!codegen.loops[idx]); 1330 // Emit invariants at this loop sequence level. 1331 genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true); 1332 // Emit further intitialization at this loop sequence level. 1333 unsigned l0 = merger.set(lts)[0]; 1334 bool needsUniv = 1335 genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits); 1336 // Maintain the universal index only if it is actually 1337 // consumed by a subsequent lattice point. 1338 if (needsUniv) { 1339 unsigned lsize = merger.set(lts).size(); 1340 for (unsigned i = 1; i < lsize; i++) { 1341 unsigned li = merger.set(lts)[i]; 1342 if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) 1343 return true; 1344 } 1345 } 1346 return false; 1347 } 1348 1349 /// Starts a single loop in current sequence. 1350 static Operation *startLoop(Merger &merger, CodeGen &codegen, 1351 PatternRewriter &rewriter, linalg::GenericOp op, 1352 std::vector<unsigned> &topSort, unsigned at, 1353 unsigned li, bool needsUniv) { 1354 assert(codegen.curVecLength == 1); 1355 // Emit the for/while-loop control. 1356 Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at, 1357 needsUniv, merger.lat(li).simple); 1358 // Emit the locals for this loop. 
1359 genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, 1360 merger.lat(li).bits); 1361 return loop; 1362 } 1363 1364 /// Ends a single loop in current sequence. Returns new values for needsUniv. 1365 static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 1366 linalg::GenericOp op, Operation *loop, unsigned idx, 1367 unsigned li, bool needsUniv) { 1368 codegen.curVecLength = 1; 1369 // End a while-loop. 1370 if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) { 1371 genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv, 1372 merger.lat(li).bits, whileOp); 1373 return needsUniv; 1374 } 1375 // End a for-loop. 1376 genForInduction(merger, codegen, rewriter, op, loop); 1377 return false; 1378 } 1379 1380 /// Ends a loop sequence at given level. 1381 static void endLoopSeq(Merger &merger, CodeGen &codegen, 1382 PatternRewriter &rewriter, linalg::GenericOp op, 1383 unsigned exp, unsigned idx, unsigned ldx) { 1384 assert(codegen.curVecLength == 1); 1385 codegen.loops[idx] = Value(); 1386 // Bring a pending reduction back from SIMD form when sequence ends. 1387 if (codegen.redVal) 1388 if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>()) 1389 updateReduc(merger, codegen, 1390 genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp)); 1391 // Unmark bookkeeping of invariants and loop index. 1392 genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false); 1393 } 1394 1395 /// Recursively generates code while computing iteration lattices in order 1396 /// to manage the complexity of implementing co-iteration over unions 1397 /// and intersections of sparse iterations spaces. 1398 static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 1399 linalg::GenericOp op, std::vector<unsigned> &topSort, 1400 unsigned exp, unsigned at) { 1401 // At each leaf, assign remaining tensor (sub)expression to output tensor. 1402 if (at == topSort.size()) { 1403 Value rhs = genExp(merger, codegen, rewriter, op, exp); 1404 genTensorStore(merger, codegen, rewriter, op, rhs); 1405 return; 1406 } 1407 1408 // Construct iteration lattices for current loop index, with L0 at top. 1409 unsigned idx = topSort[at]; 1410 unsigned ldx = at == 0 ? -1u : topSort[at - 1]; 1411 unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx)); 1412 1413 // Start a loop sequence. 1414 bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at, 1415 idx, ldx, lts); 1416 1417 // Emit a loop for every lattice point L0 >= Li in this loop sequence. 1418 unsigned lsize = merger.set(lts).size(); 1419 for (unsigned i = 0; i < lsize; i++) { 1420 // Start a loop. 1421 unsigned li = merger.set(lts)[i]; 1422 Operation *loop = 1423 startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv); 1424 1425 // Visit all lattices points with Li >= Lj to generate the 1426 // loop-body, possibly with if statements for coiteration. 1427 Value ifInput = codegen.redVal; 1428 bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr; 1429 for (unsigned j = 0; j < lsize; j++) { 1430 unsigned lj = merger.set(lts)[j]; 1431 unsigned ej = merger.lat(lj).exp; 1432 if (li == lj || merger.latGT(li, lj)) { 1433 // Recurse into body of each branch. 
/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }

  // Construct iteration lattices for the current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points Lj with Li >= Lj to generate the
    // loop-body, possibly with if-statements for co-iteration.
    Value ifInput = codegen.redVal;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, ifInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, idx, ldx);
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format.
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
                                        codegen.sparseOut == lhs);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = codegen.buffers.back(); // value array
    rewriter.replaceOpWithNewOp<memref::TensorLoadOp>(op, resType, val);
  }
}
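// Illustrative sketch only (assumes f64 values; buffer names are made up):
// for a kernel x(i) = a(i) * b(i), with a annotated as compressed and
// b, x dense, the recursion above roughly synthesizes
//
//   %lo = memref.load %a_pointers[%c0] : memref<?xindex>
//   %hi = memref.load %a_pointers[%c1] : memref<?xindex>
//   scf.for %p = %lo to %hi step %c1 {
//     %i  = memref.load %a_indices[%p] : memref<?xindex>
//     %va = memref.load %a_values[%p] : memref<?xf64>
//     %vb = memref.load %b_buffer[%i] : memref<?xf64>
//     %m  = arith.mulf %va, %vb : f64
//     memref.store %m, %x_buffer[%i] : memref<?xf64>
//   }
//
// i.e., a single for-loop over the nonzeros of a, since the multiplication
// only needs the intersection of the iteration spaces.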
//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure tensors
    // are visited in natural index order. Fails on cycles. This assumes
    // that higher-level passes have already put the tensors in each
    // tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
    if (!optExp.hasValue())
      return failure();
    unsigned exp = optExp.getValue();

    // Rejects an inadmissable tensor expression.
    OpOperand *sparseOut = nullptr;
    if (!isAdmissableTensorExp(merger, op, exp, &sparseOut))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops, sparseOut);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
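// A minimal usage sketch (illustrative only; the surrounding pass
// boilerplate and the greedy pattern driver invocation are assumed to be
// provided elsewhere):
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));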