1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements converting sparse tensor types to actual sparse code. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 16 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/SCF/SCF.h" 21 #include "mlir/Dialect/SCF/Transforms.h" 22 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 23 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 24 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 25 #include "mlir/Dialect/StandardOps/IR/Ops.h" 26 #include "mlir/Dialect/Vector/VectorOps.h" 27 #include "mlir/IR/Matchers.h" 28 #include "mlir/IR/TensorEncoding.h" 29 #include "llvm/ADT/SmallBitVector.h" 30 31 using namespace mlir; 32 using namespace mlir::sparse_tensor; 33 34 //===----------------------------------------------------------------------===// 35 // Declarations of data structures. 36 //===----------------------------------------------------------------------===// 37 38 namespace { 39 40 // Iteration graph sorting. 41 enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 }; 42 43 // Reduction kinds. 44 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor }; 45 46 // Code generation. 47 struct CodeGen { 48 CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops, 49 OpOperand *op, unsigned nest) 50 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 51 pointers(numTensors, std::vector<Value>(numLoops)), 52 indices(numTensors, std::vector<Value>(numLoops)), 53 highs(numTensors, std::vector<Value>(numLoops)), 54 pidxs(numTensors, std::vector<Value>(numLoops)), 55 idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(), 56 redKind(kNoReduc), sparseOut(op), outerParNest(nest), lexIdx(), 57 curVecLength(1), curVecMask() {} 58 /// Sparsification options. 59 SparsificationOptions options; 60 /// Universal dense indices and upper bounds (by index). The loops array 61 /// is updated with the value of the universal dense index in the current 62 /// loop. The sizes array is set once with the inferred dimension sizes. 63 std::vector<Value> loops; 64 std::vector<Value> sizes; 65 /// Buffers for storing dense and sparse numerical values (by tensor). 66 /// This array is set once during bufferization of all tensors. 67 std::vector<Value> buffers; 68 /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 69 /// This array is set once during bufferization of all sparse tensors. 70 std::vector<std::vector<Value>> pointers; 71 std::vector<std::vector<Value>> indices; 72 /// Sparse iteration information (by tensor and index). These arrays 73 /// are updated to remain current within the current loop. 
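/// (Roughly: for a sparse dimension of tensor t at loop i, pidxs[t][i] holds
/// the current position into the pointer/index storage, highs[t][i] the
/// exclusive upper bound on that position range, and idxs[t][i] the index
/// value loaded at the current position.)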
74 std::vector<std::vector<Value>> highs; 75 std::vector<std::vector<Value>> pidxs; 76 std::vector<std::vector<Value>> idxs; 77 /// Current reduction, updated during code generation. When indices of a 78 /// reduction are exhausted, all inner loops can use a scalarized reduction. 79 unsigned redExp; 80 Value redVal; 81 Reduction redKind; 82 // Sparse tensor as output. Implemented either through direct injective 83 // insertion in lexicographic index order (where indices are updated 84 // in the temporary array `lexIdx`) or TODO: access pattern expansion 85 OpOperand *sparseOut; 86 unsigned outerParNest; 87 Value lexIdx; 88 // Current vector length and mask. 89 unsigned curVecLength; 90 Value curVecMask; 91 }; 92 93 } // namespace 94 95 //===----------------------------------------------------------------------===// 96 // Sparse compiler analysis methods. 97 //===----------------------------------------------------------------------===// 98 99 /// Helper method to apply dimension ordering permutation. 100 static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) { 101 if (enc) { 102 auto order = enc.getDimOrdering(); 103 if (order) { 104 assert(order.isPermutation()); 105 return order.getDimPosition(d); 106 } 107 } 108 return d; 109 } 110 111 /// Helper method to translate dim level type to internal representation. 112 static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) { 113 if (enc) { 114 SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 115 if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 116 return Dim::kSparse; 117 if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 118 return Dim::kSingle; 119 } 120 return Dim::kDense; 121 } 122 123 /// Helper method to inspect affine expressions. Rejects cases where the 124 /// same index is used more than once. Also rejects affine expressions 125 /// that are not a direct index for annotated tensors. 126 // TODO: accept more affine cases for sparse tensors 127 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim, 128 bool isDense) { 129 switch (a.getKind()) { 130 case AffineExprKind::DimId: { 131 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 132 if (!merger.isDim(tensor, idx, Dim::kUndef)) 133 return false; // used more than once 134 merger.setDim(tensor, idx, dim); 135 return true; 136 } 137 case AffineExprKind::Add: 138 case AffineExprKind::Mul: { 139 if (!isDense) 140 return false; 141 auto binOp = a.cast<AffineBinaryOpExpr>(); 142 return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) && 143 findAffine(merger, tensor, binOp.getRHS(), dim, isDense); 144 } 145 case AffineExprKind::Constant: 146 return isDense; 147 default: 148 return false; 149 } 150 } 151 152 /// Helper method to inspect sparse encodings in the tensor types. 153 /// Fills the per-dimension sparsity information for all tensors. 154 /// Returns true if the sparse annotations and affine subscript 155 /// expressions of all tensors are admissable. Returns false if 156 /// no annotations are found or inadmissable constructs occur. 
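/// For example, with a hypothetical CSR-annotated operand accessed as A(i,j),
/// both subscripts are plain loop indices and are accepted, whereas a compound
/// subscript such as A(i+j,k) on an annotated operand is rejected by findAffine.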
157 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 158 bool annotated = false; 159 for (OpOperand *t : op.getInputAndOutputOperands()) { 160 auto map = op.getTiedIndexingMap(t); 161 auto enc = getSparseTensorEncoding(t->get().getType()); 162 if (enc) 163 annotated = true; 164 assert(map.getNumResults() == op.getRank(t)); 165 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 166 unsigned tensor = t->getOperandNumber(); 167 AffineExpr a = map.getResult(perm(enc, d)); 168 if (!findAffine(merger, tensor, a, toDim(enc, d), !enc)) 169 return false; // inadmissable affine expression 170 } 171 } 172 return annotated; 173 } 174 175 /// A DFS helper to compute a topological sort. Note that recursion is 176 /// bounded by the number of implicit loops, which is always small. 177 /// Returns false when a cycle is detected. 178 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 179 std::vector<unsigned> &topSort, 180 std::vector<std::vector<bool>> &adjM) { 181 if (visit[i] != 0) 182 return visit[i] != 1; // 1 denotes cycle! 183 visit[i] = 1; 184 for (unsigned j = 0, e = visit.size(); j < e; j++) 185 if (adjM[i][j]) 186 if (!topSortDFS(j, visit, topSort, adjM)) 187 return false; 188 visit[i] = 2; 189 topSort.push_back(i); 190 return true; 191 } 192 193 /// Helper method to add all constraints from the indices in one affine 194 /// expression before all indices in the other affine expression. For 195 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3. 196 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM, 197 AffineExpr a, AffineExpr b, unsigned fidx) { 198 switch (a.getKind()) { 199 case AffineExprKind::DimId: { 200 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 201 if (b) 202 addAffineOrderings(adjM, b, AffineExpr(), idx); 203 else 204 adjM[fidx][idx] = true; 205 break; 206 } 207 case AffineExprKind::Add: 208 case AffineExprKind::Mul: { 209 auto binOp = a.cast<AffineBinaryOpExpr>(); 210 addAffineOrderings(adjM, binOp.getLHS(), b, fidx); 211 addAffineOrderings(adjM, binOp.getRHS(), b, fidx); 212 break; 213 } 214 default: 215 break; 216 } 217 } 218 219 /// Computes a topologically sorted iteration graph for the linalg operation. 220 /// Ensures all tensors are visited in natural index order. This is essential 221 /// for sparse storage formats since these only support access along fixed 222 /// dimensions. Even for dense storage formats, however, the natural index 223 /// order yields innermost unit-stride access with better spatial locality. 224 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, 225 std::vector<unsigned> &topSort, 226 unsigned mask) { 227 // Set up an n x n from/to adjacency matrix of the iteration graph 228 // for the implicit loop indices i_0 .. i_n-1. 229 unsigned n = op.getNumLoops(); 230 std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false)); 231 232 // Iterate over the indexing maps of every tensor in the tensor expression. 233 for (OpOperand *t : op.getInputAndOutputOperands()) { 234 auto map = op.getTiedIndexingMap(t); 235 auto enc = getSparseTensorEncoding(t->get().getType()); 236 assert(map.getNumDims() == n); 237 // Skip dense tensor constraints when not requested. 238 if (!(mask & SortMask::kIncludeDense) && !enc) 239 continue; 240 // Each tensor expression and optional dimension ordering (row-major 241 // by default) puts an ordering constraint on the loop indices. 
For
242 //   example, the tensor expression A_ijk forces the ordering i < j < k
243 //   on the loop indices if no explicit dimension ordering is given.
244     for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
245       AffineExpr f = map.getResult(perm(enc, d - 1));
246       AffineExpr t = map.getResult(perm(enc, d));
247       addAffineOrderings(adjM, f, t, 0);
248     }
249     // Push unrelated loops into sparse iteration space, so these
250     // will be skipped more often.
251     if (mask & SortMask::kIncludeUndef) {
252       unsigned tensor = t->getOperandNumber();
253       for (unsigned i = 0; i < n; i++)
254         if (merger.isDim(tensor, i, Dim::kSparse))
255           for (unsigned j = 0; j < n; j++)
256             if (merger.isDim(tensor, j, Dim::kUndef))
257               adjM[i][j] = true;
258     }
259   }
260
261   // Topologically sort the iteration graph to determine loop order.
262   // Report failure for a cyclic iteration graph.
263   topSort.clear();
264   topSort.reserve(n);
265   std::vector<unsigned> visit(n, 0);
266   for (unsigned i = 0; i < n; i++)
267     if (visit[i] == 0)
268       if (!topSortDFS(i, visit, topSort, adjM))
269         return false; // cycle!
270   std::reverse(std::begin(topSort), std::end(topSort));
271   return true;
272 }
273
274 /// Returns true if tensor has an in-place annotation.
275 static bool isInPlace(Value val) {
276   if (auto arg = val.dyn_cast<BlockArgument>())
277     if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
278       if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
279               arg.getArgNumber(),
280               linalg::comprehensive_bufferize::BufferizableOpInterface::
281                   kInplaceableAttrName))
282         return attr.getValue();
283   return false;
284 }
285
286 /// Returns true if tensor materializes uninitialized into the computation.
287 static bool isMaterializing(Value val) {
288   return val.getDefiningOp<linalg::InitTensorOp>() ||
289          val.getDefiningOp<InitOp>();
290 }
291
292 /// Returns true when the tensor expression is admissable for codegen.
293 /// Since all sparse input tensors are admissable, we just need to check
294 /// whether the output tensor in the tensor expression codegen is admissable.
295 /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
296 /// nesting depth when a "truly dynamic" sparse tensor output occurs.
297 static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
298                                   std::vector<unsigned> &topSort, unsigned exp,
299                                   OpOperand **sparseOut,
300                                   unsigned &outerParNest) {
301   OpOperand *lhs = op.getOutputOperand(0);
302   unsigned tensor = lhs->getOperandNumber();
303   auto enc = getSparseTensorEncoding(lhs->get().getType());
304   // A non-annotated output tensor is assumed dense, and becomes a random
305   // access n-dim memref. Admissable since insertions cannot occur.
306   if (!enc)
307     return true;
308   // An all-dense annotated "sparse" output tensor becomes a linearized random
309   // access 1-dim memref. Also admissable since insertions cannot occur.
310   bool allDense = true;
311   auto iteratorTypes = op.iterator_types().getValue();
312   unsigned numLoops = iteratorTypes.size();
313   for (unsigned i = 0; i < numLoops; i++)
314     if (merger.isDim(tensor, i, Dim::kSparse)) {
315       allDense = false;
316       break;
317     }
318   if (allDense)
319     return true;
320   // A tensor expression with a sparse output tensor that changes its values
321   // but not its nonzero structure, an operation called "simply dynamic" in
322   // [Bik96,Ch9], is also admissable without special codegen, provided
323   // the tensor's underlying sparse storage scheme can be modified in place.
324 if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get())) 325 return true; 326 // Accept "truly dynamic" if the output tensor materializes uninitialized 327 // into the computation and insertions occur in lexicographic index order. 328 if (isMaterializing(lhs->get())) { 329 unsigned nest = 0; 330 for (unsigned i = 0; i < numLoops; i++) { 331 if (isReductionIterator(iteratorTypes[topSort[i]])) 332 break; // terminate at first reduction 333 nest++; 334 } 335 // Determine admissable dynamic insertion situations: 336 // (1) fully injective, since there are no reductions, 337 // (2) admissable 1-d expansion in innermost dimension. TODO: accept 338 if (nest == op.getRank(lhs)) { 339 *sparseOut = lhs; 340 outerParNest = nest; 341 return true; 342 } 343 } 344 return false; 345 } 346 347 //===----------------------------------------------------------------------===// 348 // Sparse compiler synthesis methods (reductions). 349 //===----------------------------------------------------------------------===// 350 351 /// Maps reduction kind to name encoding. 352 static StringRef getReductionName(Reduction kind) { 353 switch (kind) { 354 case kNoReduc: 355 break; 356 case kSum: 357 return "add"; 358 case kProduct: 359 return "mul"; 360 case kAnd: 361 return "and"; 362 case kOr: 363 return "or"; 364 case kXor: 365 return "xor"; 366 } 367 llvm_unreachable("unknown reduction kind"); 368 } 369 370 /// Maps operation to reduction. 371 static Reduction getReduction(Kind kind) { 372 switch (kind) { 373 case Kind::kAddF: 374 case Kind::kAddI: 375 case Kind::kSubF: 376 case Kind::kSubI: 377 return kSum; 378 case Kind::kMulF: 379 case Kind::kMulI: 380 return kProduct; 381 case Kind::kAndI: 382 return kAnd; 383 case Kind::kOrI: 384 return kOr; 385 case Kind::kXorI: 386 return kXor; 387 default: 388 llvm_unreachable("unexpected reduction operator"); 389 } 390 } 391 392 /// Generates an initial value for a vector reduction, following the scheme 393 /// given in Chapter 5 of "The Software Vectorization Handbook", where the 394 /// initial scalar value is correctly embedded in the vector reduction value, 395 /// and a straightforward horizontal reduction will complete the operation. 396 static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter, 397 Location loc, VectorType vtp) { 398 Value r = codegen.redVal; 399 switch (codegen.redKind) { 400 case kNoReduc: 401 break; 402 case kSum: 403 case kXor: { 404 // Initialize reduction vector to: | 0 | .. | 0 | r | 405 Attribute zero = rewriter.getZeroAttr(vtp); 406 Value vec = rewriter.create<arith::ConstantOp>(loc, vtp, zero); 407 return rewriter.create<vector::InsertElementOp>( 408 loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0)); 409 } 410 case kProduct: { 411 // Initialize reduction vector to: | 1 | .. | 1 | r | 412 Type etp = vtp.getElementType(); 413 Attribute one; 414 if (etp.isa<FloatType>()) 415 one = rewriter.getFloatAttr(etp, 1.0); 416 else 417 one = rewriter.getIntegerAttr(etp, 1); 418 Value vec = rewriter.create<arith::ConstantOp>( 419 loc, vtp, DenseElementsAttr::get(vtp, one)); 420 return rewriter.create<vector::InsertElementOp>( 421 loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0)); 422 } 423 case kAnd: 424 case kOr: 425 // Initialize reduction vector to: | r | .. | r | r | 426 return rewriter.create<vector::BroadcastOp>(loc, vtp, r); 427 } 428 llvm_unreachable("unknown reduction kind"); 429 } 430 431 /// Generates final value for a vector reduction. 
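/// For example, a vectorized sum reduction is finalized by emitting (roughly)
/// a vector.reduction "add" over the SIMD value to obtain the scalar result.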
432 static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
433                                Location loc, VectorType vtp) {
434   StringRef name = getReductionName(codegen.redKind);
435   StringAttr kind = rewriter.getStringAttr(name);
436   return rewriter.create<vector::ReductionOp>(loc, vtp.getElementType(), kind,
437                                               codegen.redVal, ValueRange{});
438 }
439
440 /// Updates scalarized reduction value.
441 static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
442   assert(codegen.redKind != kNoReduc);
443   codegen.redVal = merger.exp(codegen.redExp).val = reduc;
444 }
445
446 //===----------------------------------------------------------------------===//
447 // Sparse compiler synthesis methods (statements and expressions).
448 //===----------------------------------------------------------------------===//
449
450 /// Maps sparse integer option to actual integral storage type.
451 static Type genIntType(PatternRewriter &rewriter, unsigned width) {
452   if (width == 0)
453     return rewriter.getIndexType();
454   return rewriter.getIntegerType(width);
455 }
456
457 /// Generates buffer for the output tensor. Note that all sparse kernels
458 /// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
459 /// the output buffer is already initialized to all zeroes and only nonzero
460 /// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
461 /// only nonzero values are used for the updates and no assumption on the
462 /// original contents of the output buffer is necessary.
463 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
464                              linalg::GenericOp op, MemRefType denseTp,
465                              ArrayRef<Value> args) {
466   Location loc = op.getLoc();
467   Value tensor = op.getOutputOperand(0)->get();
468   // The output tensor could simply materialize from the buffer that will
469   // be generated for the tensor present in the outs() clause. This has
470   // the major advantage that the sparse kernel only updates the nonzero
471   // positions for the output tensor.
472   if (isInPlace(tensor))
473     return rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
474   // By default, a new buffer is allocated which is initialized to the
475   // tensor defined in the outs() clause. This is always correct but
476   // introduces a dense initialization component that may negatively
477   // impact the running complexity of the sparse kernel. If the tensor
478   // materializes into the computation, we need to preserve the zero
479   // initialization assumption of all sparse output buffers.
480   if (isMaterializing(tensor)) {
481     Type tp = denseTp.getElementType();
482     Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
483     Value zero =
484         rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
485     rewriter.create<linalg::FillOp>(loc, zero, alloc);
486     return alloc;
487   }
488   Value init = rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
489   Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
490   rewriter.create<memref::CopyOp>(loc, init, alloc);
491   return alloc;
492 }
493
494 /// Local bufferization of all dense and sparse data structures.
495 /// This code enables testing the first prototype sparse compiler.
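/// Roughly: dense tensors are bufferized with bufferization.to_memref, the
/// values of annotated sparse tensors are obtained with sparse_tensor
/// primitives (ToPointersOp/ToIndicesOp/ToValuesOp), and the dense output
/// buffer is handled separately by genOutputBuffer above.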
496 // TODO: replace this with a proliferated bufferization strategy
497 static void genBuffers(Merger &merger, CodeGen &codegen,
498                        PatternRewriter &rewriter, linalg::GenericOp op) {
499   Location loc = op.getLoc();
500   assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
501   // For every tensor, find lower and upper bound on dimensions, set the
502   // same bounds on loop indices, and obtain dense or sparse buffer(s).
503   SmallVector<Value, 4> args;
504   for (OpOperand *t : op.getInputAndOutputOperands()) {
505     unsigned tensor = t->getOperandNumber();
506     auto shape = op.getShape(t);
507     auto map = op.getTiedIndexingMap(t);
508     auto enc = getSparseTensorEncoding(t->get().getType());
509     // Scan all dimensions of current tensor.
510     args.clear();
511     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
512       AffineExpr a = map.getResult(perm(enc, d));
513       if (a.getKind() != AffineExprKind::DimId)
514         continue; // compound
515       unsigned idx = a.cast<AffineDimExpr>().getPosition();
516       // Handle sparse storage schemes.
517       if (merger.isDim(tensor, idx, Dim::kSparse)) {
518         auto dynShape = {ShapedType::kDynamicSize};
519         auto ptrTp = MemRefType::get(
520             dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
521         auto indTp = MemRefType::get(
522             dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
523         Value dim = rewriter.create<arith::ConstantIndexOp>(loc, d);
524         // Generate sparse primitives to obtain pointers and indices.
525         codegen.pointers[tensor][idx] =
526             rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
527         codegen.indices[tensor][idx] =
528             rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
529       }
530       // Find upper bound in current dimension.
531       unsigned p = perm(enc, d);
532       Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
533       if (shape[p] == MemRefType::kDynamicSize)
534         args.push_back(up);
535       assert(codegen.highs[tensor][idx] == nullptr);
536       codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
537     }
538     // Perform the required bufferization. Dense inputs materialize
539     // from the input tensors. Dense outputs need special handling.
540     // Sparse inputs use sparse primitives to obtain the values.
541     // We also accept in-place all-dense annotated "sparse" outputs.
542     Type elementType = getElementTypeOrSelf(t->get().getType());
543     if (!enc) {
544       // Non-annotated dense tensors.
545       auto denseTp = MemRefType::get(shape, elementType);
546       if (tensor < op.getNumInputs())
547         codegen.buffers[tensor] =
548             rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, t->get());
549       else
550         codegen.buffers[tensor] =
551             genOutputBuffer(codegen, rewriter, op, denseTp, args);
552     } else if (t == codegen.sparseOut) {
553       // True sparse output needs a lexIdx array.
554       Value rank = rewriter.create<arith::ConstantIndexOp>(loc, op.getRank(t));
555       auto dynShape = {ShapedType::kDynamicSize};
556       auto memTp = MemRefType::get(dynShape, rewriter.getIndexType());
557       codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank);
558     } else {
559       // Annotated sparse tensors.
560       auto dynShape = {ShapedType::kDynamicSize};
561       auto sparseTp = MemRefType::get(dynShape, elementType);
562       codegen.buffers[tensor] =
563           rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
564     }
565   }
566 }
567
568 /// Constructs vector type.
569 static VectorType vectorType(CodeGen &codegen, Type etp) {
570   return VectorType::get(codegen.curVecLength, etp);
571 }
572
573 /// Constructs vector type from pointer.
574 static VectorType vectorType(CodeGen &codegen, Value ptr) { 575 return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType()); 576 } 577 578 /// Constructs vector iteration mask. 579 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter, 580 Value iv, Value lo, Value hi, Value step) { 581 Location loc = iv.getLoc(); 582 VectorType mtp = vectorType(codegen, genIntType(rewriter, 1)); 583 // Special case if the vector length evenly divides the trip count (for 584 // example, "for i = 0, 128, 16"). A constant all-true mask is generated 585 // so that all subsequent masked memory operations are immediately folded 586 // into unconditional memory operations. 587 IntegerAttr loInt, hiInt, stepInt; 588 if (matchPattern(lo, m_Constant(&loInt)) && 589 matchPattern(hi, m_Constant(&hiInt)) && 590 matchPattern(step, m_Constant(&stepInt))) { 591 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) 592 return rewriter.create<vector::BroadcastOp>( 593 loc, mtp, rewriter.create<arith::ConstantIntOp>(loc, 1, 1)); 594 } 595 // Otherwise, generate a vector mask that avoids overrunning the upperbound 596 // during vector execution. Here we rely on subsequent loop optimizations to 597 // avoid executing the mask in all iterations, for example, by splitting the 598 // loop into an unconditional vector loop and a scalar cleanup loop. 599 auto minMap = AffineMap::get( 600 /*dimCount=*/2, /*symbolCount=*/1, 601 {rewriter.getAffineSymbolExpr(0), 602 rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)}, 603 rewriter.getContext()); 604 Value end = 605 rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step}); 606 return rewriter.create<vector::CreateMaskOp>(loc, mtp, end); 607 } 608 609 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. 610 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter, 611 Value ptr, ArrayRef<Value> args) { 612 Location loc = ptr.getLoc(); 613 VectorType vtp = vectorType(codegen, ptr); 614 Value pass = 615 rewriter.create<arith::ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp)); 616 if (args.back().getType().isa<VectorType>()) { 617 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 618 Value indexVec = args.back(); 619 scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0); 620 return rewriter.create<vector::GatherOp>( 621 loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); 622 } 623 return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args, 624 codegen.curVecMask, pass); 625 } 626 627 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. 628 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter, 629 Value rhs, Value ptr, ArrayRef<Value> args) { 630 Location loc = ptr.getLoc(); 631 if (args.back().getType().isa<VectorType>()) { 632 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 633 Value indexVec = args.back(); 634 scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0); 635 rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, 636 codegen.curVecMask, rhs); 637 return; 638 } 639 rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask, 640 rhs); 641 } 642 643 /// Generates a vectorized invariant. Here we rely on subsequent loop 644 /// optimizations to hoist the invariant broadcast out of the vector loop. 
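/// E.g., a scalar value %s of type f32 becomes, roughly,
/// vector.broadcast %s : f32 to vector<16xf32> for vector length 16.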
645 static Value genVectorInvariantValue(CodeGen &codegen, 646 PatternRewriter &rewriter, Value val) { 647 VectorType vtp = vectorType(codegen, val.getType()); 648 return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 649 } 650 651 /// Generates an affine expression. 652 // 653 // TODO: generalize for sparse tensor subscripts 654 // 655 static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter, 656 AffineExpr a, Location loc) { 657 switch (a.getKind()) { 658 case AffineExprKind::DimId: { 659 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 660 return codegen.loops[idx]; // universal dense index 661 } 662 case AffineExprKind::Add: { 663 auto binOp = a.cast<AffineBinaryOpExpr>(); 664 return rewriter.create<arith::AddIOp>( 665 loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 666 genAffine(codegen, rewriter, binOp.getRHS(), loc)); 667 } 668 case AffineExprKind::Mul: { 669 auto binOp = a.cast<AffineBinaryOpExpr>(); 670 return rewriter.create<arith::MulIOp>( 671 loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 672 genAffine(codegen, rewriter, binOp.getRHS(), loc)); 673 } 674 case AffineExprKind::Constant: { 675 int64_t c = a.cast<AffineConstantExpr>().getValue(); 676 return rewriter.create<arith::ConstantIndexOp>(loc, c); 677 } 678 default: 679 llvm_unreachable("unexpected affine subscript"); 680 } 681 } 682 683 /// Generates subscript for load/store on a dense or sparse tensor. 684 static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter, 685 linalg::GenericOp op, OpOperand *t, 686 SmallVector<Value, 4> &args) { 687 unsigned tensor = t->getOperandNumber(); 688 auto map = op.getTiedIndexingMap(t); 689 auto enc = getSparseTensorEncoding(t->get().getType()); 690 unsigned rank = map.getNumResults(); 691 if (enc) { 692 // Note that currently, all sparse subscripts are simple. 693 // TODO: accept affine too? 694 AffineExpr a = map.getResult(perm(enc, rank - 1)); 695 assert(a.getKind() == AffineExprKind::DimId); 696 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 697 assert(codegen.pidxs[tensor][idx] != nullptr); 698 args.push_back(codegen.pidxs[tensor][idx]); // position index 699 } else { 700 for (unsigned d = 0; d < rank; d++) { 701 AffineExpr a = map.getResult(perm(enc, d)); 702 args.push_back(genAffine(codegen, rewriter, a, op.getLoc())); 703 } 704 } 705 return codegen.buffers[tensor]; 706 } 707 708 /// Generates a load on a dense or sparse tensor. 709 static Value genTensorLoad(Merger &merger, CodeGen &codegen, 710 PatternRewriter &rewriter, linalg::GenericOp op, 711 unsigned exp) { 712 // Test if the load was hoisted to a higher loop nest. 713 Value val = merger.exp(exp).val; 714 if (val) { 715 if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>()) 716 return genVectorInvariantValue(codegen, rewriter, val); 717 return val; 718 } 719 // Insertion (a sparse tensor output "loads" as zero). 720 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 721 if (t == codegen.sparseOut) { 722 Type tp = getElementTypeOrSelf(t->get().getType()); 723 return rewriter.create<arith::ConstantOp>(op.getLoc(), tp, 724 rewriter.getZeroAttr(tp)); 725 } 726 // Actual load. 727 SmallVector<Value, 4> args; 728 Value ptr = genSubscript(codegen, rewriter, op, t, args); 729 if (codegen.curVecLength > 1) 730 return genVectorLoad(codegen, rewriter, ptr, args); 731 return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args); 732 } 733 734 /// Generates a store on a dense or sparse tensor. 
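/// Three cases are handled below: updating a scalarized reduction value,
/// inserting into a sparse output tensor (LexInsertOp), or an actual
/// (possibly vectorized) store into a dense buffer.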
735 static void genTensorStore(Merger &merger, CodeGen &codegen,
736                            PatternRewriter &rewriter, linalg::GenericOp op,
737                            Value rhs) {
738   Location loc = op.getLoc();
739   // Test if this is a scalarized reduction.
740   if (codegen.redVal) {
741     if (codegen.curVecLength > 1)
742       rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
743                                       codegen.redVal);
744     updateReduc(merger, codegen, rhs);
745     return;
746   }
747   // Insertion.
748   OpOperand *t = op.getOutputOperand(0);
749   if (t == codegen.sparseOut) {
750     rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs);
751     return;
752   }
753   // Actual store.
754   SmallVector<Value, 4> args;
755   Value ptr = genSubscript(codegen, rewriter, op, t, args);
756   if (codegen.curVecLength > 1)
757     genVectorStore(codegen, rewriter, rhs, ptr, args);
758   else
759     rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
760 }
761
762 /// Generates a pointer/index load from the sparse storage scheme. Narrower
763 /// data types need to be zero extended before casting the value into the
764 /// index type used for looping and indexing.
765 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
766                      Value ptr, Value s) {
767   // See https://llvm.org/docs/GetElementPtr.html for some background on
768   // the complications described below.
769   if (codegen.curVecLength > 1) {
770     // Since the index vector is used in subsequent gather/scatter operations,
771     // which effectively defines an unsigned pointer + signed index, we must
772     // zero extend the vector to an index width. For 8-bit and 16-bit values,
773     // a 32-bit index width suffices. For 32-bit values, zero extending the
774     // elements into 64-bit loses some performance since the 32-bit indexed
775     // gather/scatter is more efficient than the 64-bit index variant (if the
776     // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
777     // preserve this performance). For 64-bit values, there is no good way
778     // to state that the indices are unsigned, which creates the potential for
779     // incorrect address calculations in the unlikely case we need such
780     // extremely large offsets.
781     Type etp = ptr.getType().cast<MemRefType>().getElementType();
782     Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
783     if (!etp.isa<IndexType>()) {
784       if (etp.getIntOrFloatBitWidth() < 32)
785         vload = rewriter.create<arith::ExtUIOp>(
786             loc, vload, vectorType(codegen, genIntType(rewriter, 32)));
787       else if (etp.getIntOrFloatBitWidth() < 64 &&
788                !codegen.options.enableSIMDIndex32)
789         vload = rewriter.create<arith::ExtUIOp>(
790             loc, vload, vectorType(codegen, genIntType(rewriter, 64)));
791     }
792     return vload;
793   }
794   // For the scalar case, we simply zero extend narrower indices into 64-bit
795   // values before casting to index without a performance penalty. Here too,
796   // however, indices that already are 64-bit, in theory, cannot express the
797   // full range as explained above.
798   Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
799   if (!load.getType().isa<IndexType>()) {
800     if (load.getType().getIntOrFloatBitWidth() < 64)
801       load =
802           rewriter.create<arith::ExtUIOp>(loc, load, genIntType(rewriter, 64));
803     load =
804         rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
805   }
806   return load;
807 }
808
809 /// Generates an invariant value.
810 static Value genInvariantValue(Merger &merger, CodeGen &codegen, 811 PatternRewriter &rewriter, unsigned exp) { 812 Value val = merger.exp(exp).val; 813 if (codegen.curVecLength > 1) 814 return genVectorInvariantValue(codegen, rewriter, val); 815 return val; 816 } 817 818 /// Generates an address computation "sz * p + i". 819 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter, 820 Location loc, Value size, Value p, Value i) { 821 Value mul = rewriter.create<arith::MulIOp>(loc, size, p); 822 if (auto vtp = i.getType().dyn_cast<VectorType>()) { 823 Value inv = 824 rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType()); 825 mul = genVectorInvariantValue(codegen, rewriter, inv); 826 } 827 return rewriter.create<arith::AddIOp>(loc, mul, i); 828 } 829 830 /// Recursively generates tensor expression. 831 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 832 linalg::GenericOp op, unsigned exp) { 833 Location loc = op.getLoc(); 834 if (exp == -1u) 835 return Value(); 836 if (merger.exp(exp).kind == Kind::kTensor) 837 return genTensorLoad(merger, codegen, rewriter, op, exp); 838 if (merger.exp(exp).kind == Kind::kInvariant) 839 return genInvariantValue(merger, codegen, rewriter, exp); 840 Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0); 841 Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1); 842 return merger.buildExp(rewriter, loc, exp, v0, v1); 843 } 844 845 /// Determines if affine expression is invariant. 846 static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a, 847 unsigned ldx, bool &atLevel) { 848 switch (a.getKind()) { 849 case AffineExprKind::DimId: { 850 unsigned idx = a.cast<AffineDimExpr>().getPosition(); 851 if (idx == ldx) 852 atLevel = true; 853 return codegen.loops[idx] != nullptr; // no longer in play? 854 } 855 case AffineExprKind::Add: 856 case AffineExprKind::Mul: { 857 auto binOp = a.cast<AffineBinaryOpExpr>(); 858 return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) && 859 isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel); 860 } 861 default: 862 return true; 863 } 864 } 865 866 /// Hoists loop invariant tensor loads for which indices have been exhausted. 867 static void genInvariants(Merger &merger, CodeGen &codegen, 868 PatternRewriter &rewriter, linalg::GenericOp op, 869 unsigned exp, unsigned ldx, bool atStart, 870 Kind last = Kind::kTensor) { 871 if (exp == -1u) 872 return; 873 if (merger.exp(exp).kind == Kind::kTensor) { 874 // Inspect tensor indices. 875 bool atLevel = ldx == -1u; 876 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 877 auto map = op.getTiedIndexingMap(t); 878 auto enc = getSparseTensorEncoding(t->get().getType()); 879 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 880 AffineExpr a = map.getResult(perm(enc, d)); 881 if (!isInvariantAffine(codegen, a, ldx, atLevel)) 882 return; // still in play 883 } 884 // All exhausted at this level (atLevel denotes exactly at this level). 
885 if (!atLevel) 886 return; 887 OpOperand *lhs = op.getOutputOperand(0); 888 if (lhs == t) { 889 // Start or end a scalarized reduction 890 if (atStart) { 891 Value load = genTensorLoad(merger, codegen, rewriter, op, exp); 892 codegen.redKind = getReduction(last); 893 codegen.redExp = exp; 894 updateReduc(merger, codegen, load); 895 } else { 896 Value redVal = codegen.redVal; 897 updateReduc(merger, codegen, Value()); 898 codegen.redExp = -1u; 899 codegen.redKind = kNoReduc; 900 genTensorStore(merger, codegen, rewriter, op, redVal); 901 } 902 } else { 903 // Start or end loop invariant hoisting of a tensor load. 904 merger.exp(exp).val = 905 atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); 906 } 907 } else if (merger.exp(exp).kind != Kind::kInvariant) { 908 // Traverse into the binary operations. Note that we only hoist 909 // tensor loads, since subsequent MLIR/LLVM passes know how to 910 // deal with all other kinds of derived loop invariants. 911 Kind last = merger.exp(exp).kind; 912 unsigned e0 = merger.exp(exp).children.e0; 913 unsigned e1 = merger.exp(exp).children.e1; 914 genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last); 915 genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last); 916 } 917 } 918 919 /// Generates initialization code for the subsequent loop sequence at 920 /// current index level. Returns true if the loop sequence needs to 921 /// maintain the universal index. 922 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 923 linalg::GenericOp op, std::vector<unsigned> &topSort, 924 unsigned at, llvm::BitVector &inits) { 925 bool needsUniv = false; 926 Location loc = op.getLoc(); 927 unsigned idx = topSort[at]; 928 929 // Initialize sparse positions. 930 for (unsigned b = 0, be = inits.size(); b < be; b++) { 931 if (inits[b]) { 932 unsigned tensor = merger.tensor(b); 933 assert(idx == merger.index(b)); 934 if (merger.isDim(b, Dim::kSparse)) { 935 // Initialize sparse index. 936 unsigned pat = at; 937 for (; pat != 0; pat--) { 938 if (codegen.pidxs[tensor][topSort[pat - 1]]) 939 break; 940 } 941 Value ptr = codegen.pointers[tensor][idx]; 942 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 943 Value p0 = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0) 944 : codegen.pidxs[tensor][topSort[pat - 1]]; 945 codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); 946 Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one); 947 codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1); 948 } else { 949 // Dense index still in play. 950 needsUniv = true; 951 } 952 } 953 } 954 955 // Initialize the universal dense index. 956 codegen.loops[idx] = rewriter.create<arith::ConstantIndexOp>(loc, 0); 957 return needsUniv; 958 } 959 960 /// Returns vectorization strategy. Any implicit inner loop in the Linalg 961 /// operation is a candidate. Whether it is actually converted to SIMD code 962 /// depends on the requested strategy. 963 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) { 964 switch (codegen.options.vectorizationStrategy) { 965 case SparseVectorizationStrategy::kNone: 966 return false; 967 case SparseVectorizationStrategy::kDenseInnerLoop: 968 return isInner && !isSparse; 969 case SparseVectorizationStrategy::kAnyStorageInnerLoop: 970 return isInner; 971 } 972 llvm_unreachable("unexpected vectorization strategy"); 973 } 974 975 /// Returns parallelization strategy. 
Any implicit loop in the Linalg operation 976 /// that is marked "parallel" is a candidate. Whether it is actually converted 977 /// to a parallel operation depends on the requested strategy. 978 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, 979 bool isSparse, bool isVector) { 980 switch (codegen.options.parallelizationStrategy) { 981 case SparseParallelizationStrategy::kNone: 982 return false; 983 case SparseParallelizationStrategy::kDenseOuterLoop: 984 return isOuter && !isSparse && !isReduction && !isVector; 985 case SparseParallelizationStrategy::kAnyStorageOuterLoop: 986 return isOuter && !isReduction && !isVector; 987 case SparseParallelizationStrategy::kDenseAnyLoop: 988 return !isSparse && !isReduction && !isVector; 989 case SparseParallelizationStrategy::kAnyStorageAnyLoop: 990 return !isReduction && !isVector; 991 } 992 llvm_unreachable("unexpected parallelization strategy"); 993 } 994 995 /// Checks unit stride for dense tensors. The iteration graph may have ignored 996 /// dense access patterns in order to avoid cycles (sparse access patterns are 997 /// always placed innermost), but that means dense access has become strided. 998 /// This prevents effective vectorization. 999 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 1000 unsigned idx) { 1001 for (OpOperand *t : op.getInputAndOutputOperands()) { 1002 if (!getSparseTensorEncoding(t->get().getType())) { 1003 auto map = op.getTiedIndexingMap(t); 1004 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1005 AffineExpr a = map.getResult(d); 1006 // Report non-unit stride if innermost index appears at an outer 1007 // dimension (true non-unit stride) or if the innermost index appears 1008 // in a compound subscript in the innermost dimension. Even if the 1009 // latter is unit stride, it does not play well with scatter/gather. 1010 // TODO: accept unit stride affine innermost like a[i,j+k+1]? 1011 if (a.isFunctionOfDim(idx) && 1012 ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) 1013 return false; 1014 } 1015 } 1016 } 1017 return true; 1018 } 1019 1020 /// Generates a for-loop on a single index. 1021 static Operation *genFor(Merger &merger, CodeGen &codegen, 1022 PatternRewriter &rewriter, linalg::GenericOp op, 1023 bool isOuter, bool isInner, unsigned idx, 1024 llvm::BitVector &indices) { 1025 unsigned fb = indices.find_first(); 1026 unsigned tensor = merger.tensor(fb); 1027 assert(idx == merger.index(fb)); 1028 auto iteratorTypes = op.iterator_types().getValue(); 1029 bool isReduction = isReductionIterator(iteratorTypes[idx]); 1030 bool isSparse = merger.isDim(fb, Dim::kSparse); 1031 bool isVector = !codegen.sparseOut && 1032 isVectorFor(codegen, isInner, isSparse) && 1033 denseUnitStrides(merger, op, idx); 1034 bool isParallel = 1035 !codegen.sparseOut && 1036 isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 1037 1038 // Prepare vector length. 1039 if (isVector) 1040 codegen.curVecLength = codegen.options.vectorLength; 1041 1042 // Loop bounds and increment. 1043 Location loc = op.getLoc(); 1044 Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 1045 Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 1046 Value step = 1047 rewriter.create<arith::ConstantIndexOp>(loc, codegen.curVecLength); 1048 1049 // Emit a parallel loop. 
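  // (An scf.parallel is used here when the chosen parallelization strategy
  // applies to this loop; otherwise the sequential or vectorized scf.for
  // below is emitted instead.)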
1050 if (isParallel) { 1051 assert(!isVector); 1052 scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step); 1053 if (isSparse) 1054 codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 1055 else 1056 codegen.loops[idx] = parOp.getInductionVars()[0]; 1057 rewriter.setInsertionPointToStart(parOp.getBody()); 1058 return parOp; 1059 } 1060 1061 // Emit a sequential or vector loop. 1062 SmallVector<Value, 4> operands; 1063 if (codegen.redVal) { 1064 // In a vector loop, bring reduction into SIMD form, if not already. 1065 if (isVector && !codegen.redVal.getType().isa<VectorType>()) { 1066 VectorType vtp = vectorType(codegen, codegen.redVal.getType()); 1067 Value vred = genVectorReducInit(codegen, rewriter, loc, vtp); 1068 updateReduc(merger, codegen, vred); 1069 } 1070 operands.push_back(codegen.redVal); 1071 } 1072 scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands); 1073 if (codegen.redVal) 1074 updateReduc(merger, codegen, forOp.getRegionIterArgs().front()); 1075 // Assign induction variable to sparse or dense index. 1076 Value iv = forOp.getInductionVar(); 1077 if (isSparse) 1078 codegen.pidxs[tensor][idx] = iv; 1079 else 1080 codegen.loops[idx] = iv; 1081 rewriter.setInsertionPointToStart(forOp.getBody()); 1082 // Share vector iteration mask between all subsequent loads/stores. 1083 if (isVector) 1084 codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step); 1085 return forOp; 1086 } 1087 1088 /// Emit a while-loop for co-iteration over multiple indices. 1089 static Operation *genWhile(Merger &merger, CodeGen &codegen, 1090 PatternRewriter &rewriter, linalg::GenericOp op, 1091 unsigned idx, bool needsUniv, 1092 llvm::BitVector &indices) { 1093 SmallVector<Type, 4> types; 1094 SmallVector<Value, 4> operands; 1095 // Construct the while-loop with a parameter for each index. 1096 Type indexType = rewriter.getIndexType(); 1097 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1098 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1099 unsigned tensor = merger.tensor(b); 1100 assert(idx == merger.index(b)); 1101 types.push_back(indexType); 1102 operands.push_back(codegen.pidxs[tensor][idx]); 1103 } 1104 } 1105 if (codegen.redVal) { 1106 types.push_back(codegen.redVal.getType()); 1107 operands.push_back(codegen.redVal); 1108 } 1109 if (needsUniv) { 1110 types.push_back(indexType); 1111 operands.push_back(codegen.loops[idx]); 1112 } 1113 assert(types.size() == operands.size()); 1114 Location loc = op.getLoc(); 1115 scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); 1116 Block *before = rewriter.createBlock(&whileOp.before(), {}, types); 1117 Block *after = rewriter.createBlock(&whileOp.after(), {}, types); 1118 1119 // Build the "before" region, which effectively consists 1120 // of a conjunction of "i < upper" tests on all induction. 1121 rewriter.setInsertionPointToStart(&whileOp.before().front()); 1122 Value cond; 1123 unsigned o = 0; 1124 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1125 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1126 unsigned tensor = merger.tensor(b); 1127 assert(idx == merger.index(b)); 1128 Value op1 = before->getArgument(o); 1129 Value op2 = codegen.highs[tensor][idx]; 1130 Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 1131 op1, op2); 1132 cond = cond ? 
rewriter.create<arith::AndIOp>(loc, cond, opc) : opc; 1133 codegen.pidxs[tensor][idx] = after->getArgument(o++); 1134 } 1135 } 1136 if (codegen.redVal) 1137 updateReduc(merger, codegen, after->getArgument(o++)); 1138 if (needsUniv) 1139 codegen.loops[idx] = after->getArgument(o++); 1140 assert(o == operands.size()); 1141 rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); 1142 rewriter.setInsertionPointToStart(&whileOp.after().front()); 1143 return whileOp; 1144 } 1145 1146 /// Generates a for-loop or a while-loop, depending on whether it implements 1147 /// singleton iteration or co-iteration over the given conjunction. 1148 static Operation *genLoop(Merger &merger, CodeGen &codegen, 1149 PatternRewriter &rewriter, linalg::GenericOp op, 1150 std::vector<unsigned> &topSort, unsigned at, 1151 bool needsUniv, llvm::BitVector &indices) { 1152 unsigned idx = topSort[at]; 1153 if (indices.count() == 1) { 1154 bool isOuter = at == 0; 1155 bool isInner = at == topSort.size() - 1; 1156 return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, 1157 indices); 1158 } 1159 return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); 1160 } 1161 1162 /// Generates the local variables for this loop, consisting of the sparse 1163 /// indices, restored universal dense index, and dense positions. 1164 static void genLocals(Merger &merger, CodeGen &codegen, 1165 PatternRewriter &rewriter, linalg::GenericOp op, 1166 std::vector<unsigned> &topSort, unsigned at, 1167 bool needsUniv, llvm::BitVector &locals) { 1168 Location loc = op.getLoc(); 1169 unsigned idx = topSort[at]; 1170 1171 // Initialize sparse indices. 1172 Value min; 1173 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1174 if (locals[b] && merger.isDim(b, Dim::kSparse)) { 1175 unsigned tensor = merger.tensor(b); 1176 assert(idx == merger.index(b)); 1177 Value ptr = codegen.indices[tensor][idx]; 1178 Value s = codegen.pidxs[tensor][idx]; 1179 Value load = genLoad(codegen, rewriter, loc, ptr, s); 1180 codegen.idxs[tensor][idx] = load; 1181 if (!needsUniv) { 1182 if (min) { 1183 Value cmp = rewriter.create<arith::CmpIOp>( 1184 loc, arith::CmpIPredicate::ult, load, min); 1185 min = rewriter.create<SelectOp>(loc, cmp, load, min); 1186 } else { 1187 min = load; 1188 } 1189 } 1190 } 1191 } 1192 1193 // Merge dense universal index over minimum. 1194 if (min) { 1195 assert(!needsUniv); 1196 codegen.loops[idx] = min; 1197 } 1198 1199 // Initialize dense positions. Note that we generate dense indices of the 1200 // output tensor unconditionally, since they may not appear in the lattice, 1201 // but may be needed for linearized codegen. 1202 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1203 if ((locals[b] || merger.isOutTensor(b, idx)) && 1204 merger.isDim(b, Dim::kDense)) { 1205 unsigned tensor = merger.tensor(b); 1206 assert(idx == merger.index(b)); 1207 unsigned pat = at; 1208 for (; pat != 0; pat--) 1209 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1210 break; 1211 Value p = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0) 1212 : codegen.pidxs[tensor][topSort[pat - 1]]; 1213 codegen.pidxs[tensor][idx] = genAddress( 1214 codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); 1215 } 1216 } 1217 1218 // Move the insertion indices in lexicographic index order. 
1219 if (codegen.sparseOut) { 1220 Value pos = rewriter.create<arith::ConstantIndexOp>(loc, at); 1221 rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx, 1222 pos); 1223 } 1224 } 1225 1226 /// Generates the induction structure for a while-loop. 1227 static void genWhileInduction(Merger &merger, CodeGen &codegen, 1228 PatternRewriter &rewriter, linalg::GenericOp op, 1229 unsigned idx, bool needsUniv, 1230 llvm::BitVector &induction, 1231 scf::WhileOp whileOp) { 1232 Location loc = op.getLoc(); 1233 // Finalize each else branch of all if statements. 1234 if (codegen.redVal) { 1235 while (auto ifOp = dyn_cast_or_null<scf::IfOp>( 1236 rewriter.getInsertionBlock()->getParentOp())) { 1237 rewriter.create<scf::YieldOp>(loc, codegen.redVal); 1238 updateReduc(merger, codegen, ifOp.getResult(0)); 1239 rewriter.setInsertionPointAfter(ifOp); 1240 } 1241 } 1242 rewriter.setInsertionPointToEnd(&whileOp.after().front()); 1243 // Finalize the induction. Note that the induction could be performed 1244 // in the individual if-branches to avoid re-evaluating the conditions. 1245 // However, that would result in a rather elaborate forest of yield 1246 // instructions during code generation. Moreover, performing the induction 1247 // after the if-statements more closely resembles code generated by TACO. 1248 unsigned o = 0; 1249 SmallVector<Value, 4> operands; 1250 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 1251 for (unsigned b = 0, be = induction.size(); b < be; b++) { 1252 if (induction[b] && merger.isDim(b, Dim::kSparse)) { 1253 unsigned tensor = merger.tensor(b); 1254 assert(idx == merger.index(b)); 1255 Value op1 = codegen.idxs[tensor][idx]; 1256 Value op2 = codegen.loops[idx]; 1257 Value op3 = codegen.pidxs[tensor][idx]; 1258 Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 1259 op1, op2); 1260 Value add = rewriter.create<arith::AddIOp>(loc, op3, one); 1261 operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3)); 1262 codegen.pidxs[tensor][idx] = whileOp->getResult(o++); 1263 } 1264 } 1265 if (codegen.redVal) { 1266 operands.push_back(codegen.redVal); 1267 updateReduc(merger, codegen, whileOp->getResult(o++)); 1268 } 1269 if (needsUniv) { 1270 operands.push_back( 1271 rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one)); 1272 codegen.loops[idx] = whileOp->getResult(o++); 1273 } 1274 assert(o == operands.size()); 1275 rewriter.create<scf::YieldOp>(loc, operands); 1276 rewriter.setInsertionPointAfter(whileOp); 1277 } 1278 1279 /// Generates the induction structure for a for-loop. 1280 static void genForInduction(Merger &merger, CodeGen &codegen, 1281 PatternRewriter &rewriter, linalg::GenericOp op, 1282 Operation *loop) { 1283 Location loc = op.getLoc(); 1284 unsigned o = 0; 1285 SmallVector<Value, 4> operands; 1286 if (codegen.redVal) { 1287 operands.push_back(codegen.redVal); 1288 updateReduc(merger, codegen, loop->getResult(o++)); 1289 } 1290 assert(o == operands.size()); 1291 if (o > 0) 1292 rewriter.create<scf::YieldOp>(loc, operands); 1293 rewriter.setInsertionPointAfter(loop); 1294 } 1295 1296 /// Generates a single if-statement within a while-loop. 
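/// The condition is a conjunction over all sparse tensors in the lattice
/// point, e.g. (roughly) idxA == i && idxB == i when co-iterating tensors
/// A and B at loop index i.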
1297 static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
1298                        PatternRewriter &rewriter, linalg::GenericOp op,
1299                        unsigned idx, llvm::BitVector &conditions) {
1300   Location loc = op.getLoc();
1301   SmallVector<Type, 4> types;
1302   Value cond;
1303   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
1304     if (conditions[b]) {
1305       unsigned tensor = merger.tensor(b);
1306       assert(idx == merger.index(b));
1307       Value clause;
1308       if (merger.isDim(b, Dim::kSparse)) {
1309         Value op1 = codegen.idxs[tensor][idx];
1310         Value op2 = codegen.loops[idx];
1311         clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
1312                                                 op1, op2);
1313       } else {
1314         clause = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true
1315       }
1316       cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
1317     }
1318   }
1319   if (codegen.redVal)
1320     types.push_back(codegen.redVal.getType());
1321   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true);
1322   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
1323   return ifOp;
1324 }
1325
1326 /// Generates end of true branch of if-statement within a while-loop.
1327 static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1328                   linalg::GenericOp op, scf::IfOp ifOp, Value ifInput) {
1329   if (codegen.redVal) {
1330     rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal);
1331     updateReduc(merger, codegen, ifInput);
1332   }
1333   rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
1334 }
1335
1336 //===----------------------------------------------------------------------===//
1337 // Sparse compiler synthesis methods (loop sequence).
1338 //===----------------------------------------------------------------------===//
1339
1340 /// Starts a loop sequence at given level. Returns true if
1341 /// the universal loop index must be maintained at this level.
1342 static bool startLoopSeq(Merger &merger, CodeGen &codegen,
1343                          PatternRewriter &rewriter, linalg::GenericOp op,
1344                          std::vector<unsigned> &topSort, unsigned exp,
1345                          unsigned at, unsigned idx, unsigned ldx,
1346                          unsigned lts) {
1347   assert(codegen.curVecLength == 1);
1348   assert(!codegen.loops[idx]);
1349   // Emit invariants at this loop sequence level.
1350   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
1351   // Emit further initialization at this loop sequence level.
1352   unsigned l0 = merger.set(lts)[0];
1353   bool needsUniv =
1354       genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
1355   // Maintain the universal index only if it is actually
1356   // consumed by a subsequent lattice point.
1357   if (needsUniv) {
1358     unsigned lsize = merger.set(lts).size();
1359     for (unsigned i = 1; i < lsize; i++) {
1360       unsigned li = merger.set(lts)[i];
1361       if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
1362         return true;
1363     }
1364   }
1365   return false;
1366 }
1367
1368 /// Starts a single loop in the current sequence.
1369 static Operation *startLoop(Merger &merger, CodeGen &codegen,
1370                             PatternRewriter &rewriter, linalg::GenericOp op,
1371                             std::vector<unsigned> &topSort, unsigned at,
1372                             unsigned li, bool needsUniv) {
1373   assert(codegen.curVecLength == 1);
1374   // Emit the for/while-loop control.
1375   Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
1376                             needsUniv, merger.lat(li).simple);
1377   // Emit the locals for this loop.

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         std::vector<unsigned> &topSort, unsigned exp,
                         unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  assert(!codegen.loops[idx]);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  bool needsUniv =
      genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  if (needsUniv) {
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}

/// Starts a single loop in current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in current sequence. Returns the new value for
/// needsUniv.
static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp);
    return needsUniv;
  }
  // End a for-loop.
  genForInduction(merger, codegen, rewriter, op, loop);
  return false;
}

/// Ends a loop sequence at given level.
static void endLoopSeq(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned exp, unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  codegen.loops[idx] = Value();
  // Bring a pending reduction back from SIMD form when sequence ends.
  if (codegen.redVal)
    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
      updateReduc(merger, codegen,
                  genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
}
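
// Putting the pieces together, one loop sequence for loop index i expands
// roughly as follows (pseudo-code; Li are the points of the optimized
// lattice set, with L0 at the top):
//
//   startLoopSeq : hoist invariants, initialize pointers/indices,
//                  decide whether a universal index is needed
//   for each Li  : startLoop  -> emit scf.while or scf.for for Li
//                  body       -> if-statements for all Lj <= Li (while-loops
//                                only), recursing into the next loop level
//                  endLoop    -> emit the while/for induction
//   endLoopSeq   : reset the loop index, fold a pending SIMD reduction back
//                  to scalar form, unmark invariants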

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    Value ifInput = codegen.redVal;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, ifInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, idx, ldx);
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format.
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
                                        codegen.sparseOut == lhs);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = codegen.buffers.back(); // value array
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
  }
}
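
// Sketch of the two rematerialization forms (types and the #SparseMatrix
// alias are illustrative, not taken from a specific test):
//
//   // Annotated output: reuse the underlying sparse storage, accounting for
//   // any insertions performed in lexicographic index order.
//   %result = sparse_tensor.load %argx : tensor<?x?xf64, #SparseMatrix>
//
//   // Non-annotated output: simply load the tensor back from the
//   // bufferized values array.
//   %result = bufferization.to_tensor %values : memref<?x?xf64>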

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
    if (!optExp.hasValue())
      return failure();
    unsigned exp = optExp.getValue();

    // Rejects an inadmissible tensor expression.
    OpOperand *sparseOut = nullptr;
    unsigned outerParNest = 0;
    if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
                               outerParNest))
      return failure();

    // Recursively generates code.
    merger.setHasSparseOut(sparseOut != nullptr);
    CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
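
// A minimal usage sketch (assumed driver code, not part of this file): a pass
// would typically wire up these patterns and let the greedy driver rewrite
// every admissible linalg.generic it finds, roughly as follows.
//
//   void runOnOperation() override {
//     RewritePatternSet patterns(&getContext());
//     populateSparsificationPatterns(patterns, SparsificationOptions());
//     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//   }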