//===- Sparsification.cpp - Implementation of sparsification -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Declarations of data structures.
//===----------------------------------------------------------------------===//

namespace {

// Iteration graph sorting.
enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

// Reduction kinds.
enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        redKind(kNoReduc), curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can use a scalarized reduction.
  unsigned redExp;
  Value redVal;
  Reduction redKind;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse compiler analysis methods.
//===----------------------------------------------------------------------===//

/// Helper method to apply dimension ordering permutation.
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

/// Helper method to translate dim level type to internal representation.
static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects affine expressions
/// that are not a direct index for annotated tensors.
// TODO: accept more affine cases for sparse tensors
static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
                       bool isDense) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (!merger.isDim(tensor, idx, Dim::kUndef))
      return false; // used more than once
    merger.setDim(tensor, idx, dim);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    if (!isDense)
      return false;
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
  }
  case AffineExprKind::Constant:
    return isDense;
  default:
    return false;
  }
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissable. Returns false if
/// no annotations are found or inadmissable constructs occur.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned tensor = t->getOperandNumber();
      AffineExpr a = map.getResult(perm(enc, d));
      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissable affine expression
    }
  }
  return annotated;
}
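
// For illustration only (an editorial sketch, not code used by this file):
// with a CSR-style annotation such as
//
//   #CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"] }>
//
// an operand A with indexing map (i,j) -> (i,j) is classified above as
// Dim::kDense for loop index i and Dim::kSparse for loop index j, whereas a
// non-annotated operand contributes Dim::kDense for all of its loop indices.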
/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Helper method to add all constraints from the indices in one affine
/// expression before all indices in the other affine expression. For
/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
                               AffineExpr a, AffineExpr b, unsigned fidx) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (b)
      addAffineOrderings(adjM, b, AffineExpr(), idx);
    else
      adjM[fidx][idx] = true;
    break;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
    break;
  }
  default:
    break;
  }
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  unsigned mask) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}
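
// For illustration only (schematic): for x(i) += A(i,j) * b(j) with a
// row-major annotated matrix A, the subscript A(i,j) adds the constraint
// i < j to the adjacency matrix, so every topological order computed above
// places loop i outside loop j; with SortMask::kIncludeDense the same kind
// of constraint is also added for non-annotated dense operands.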
/// Returns true if tensor has an in-place annotation.
static bool isInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Returns true if tensor materializes into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<InitOp>();
}

/// Returns true when the tensor expression is admissable for codegen.
/// Since all sparse input tensors are admissable, we just need to check
/// whether the output tensor in the tensor expression codegen is admissable.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissable since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissable since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissable without special codegen, provided
  // the tensor's underlying sparse storage scheme can be modified in place.
  if (merger.isConjunction(tensor, exp))
    return isInPlace(lhs->get());
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}
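
// For illustration only (schematic): with a compressed output x that is
// marked in-place, an update such as x(i) = x(i) * 2.0 only overwrites
// stored values ("simply dynamic") and is accepted above, whereas
// x(i) = a(i) + b(i) may introduce new nonzero positions in x and is
// rejected until workspace support is available.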
//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (reductions).
//===----------------------------------------------------------------------===//

/// Maps reduction kind to name encoding.
static StringRef getReductionName(Reduction kind) {
  switch (kind) {
  case kNoReduc:
    break;
  case kSum:
    return "add";
  case kProduct:
    return "mul";
  case kAnd:
    return "and";
  case kOr:
    return "or";
  case kXor:
    return "xor";
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps operation to reduction.
static Reduction getReduction(Kind kind) {
  switch (kind) {
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubI:
    return kSum;
  case Kind::kMulF:
  case Kind::kMulI:
    return kProduct;
  case Kind::kAndI:
    return kAnd;
  case Kind::kOrI:
    return kOr;
  case Kind::kXorI:
    return kXor;
  default:
    llvm_unreachable("unexpected reduction operator");
  }
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
                                Location loc, VectorType vtp) {
  Value r = codegen.redVal;
  switch (codegen.redKind) {
  case kNoReduc:
    break;
  case kSum:
  case kXor: {
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    Attribute zero = rewriter.getZeroAttr(vtp);
    Value vec = rewriter.create<arith::ConstantOp>(loc, vtp, zero);
    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
  }
  case kProduct: {
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    Type etp = vtp.getElementType();
    Attribute one;
    if (etp.isa<FloatType>())
      one = rewriter.getFloatAttr(etp, 1.0);
    else
      one = rewriter.getIntegerAttr(etp, 1);
    Value vec = rewriter.create<arith::ConstantOp>(
        loc, vtp, DenseElementsAttr::get(vtp, one));
    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
  }
  case kAnd:
  case kOr:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
}

/// Generates final value for a vector reduction.
static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
                               Location loc, VectorType vtp) {
  StringRef name = getReductionName(codegen.redKind);
  StringAttr kind = rewriter.getStringAttr(name);
  return rewriter.create<vector::ReductionOp>(loc, vtp.getElementType(), kind,
                                              codegen.redVal, ValueRange{});
}

/// Updates scalarized reduction value.
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}
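
// For illustration only (schematic): for a sum reduction with running scalar
// value r and vector length 4, genVectorReducInit yields the vector
// | 0 | 0 | 0 | r |, the vector loop accumulates partial sums per lane, and
// genVectorReducEnd folds the lanes back into a single scalar through a
// horizontal vector.reduction "add".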
//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor simply could materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (isInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel. If the tensor
  // materializes into the computation, we need to preserve the zero
  // initialization assumption of all sparse output buffers.
  if (isMaterializing(tensor)) {
    Type tp = denseTp.getElementType();
    Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    rewriter.create<linalg::FillOp>(loc, zero, alloc);
    return alloc;
  }
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<arith::ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (shape[p] == MemRefType::kDynamicSize)
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
}
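
// For illustration only (a rough sketch of the bufferization above, with
// types elided): for a matrix A that has a compressed inner dimension,
// genBuffers emits roughly
//
//   %ptr = sparse_tensor.pointers %A, %c1 : ... to memref<?xindex>
//   %ind = sparse_tensor.indices %A, %c1 : ... to memref<?xindex>
//   %val = sparse_tensor.values %A : ... to memref<?xf64>
//
// while a non-annotated dense input simply materializes through
// memref.buffer_cast and the dense output goes through genOutputBuffer.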
/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, genIntType(rewriter, 1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<arith::ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}
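
// For illustration only: with vector length 16 and bounds that are not known
// to divide evenly, the mask above evaluates per iteration to roughly
// vector.create_mask(min(step, hi - iv)), disabling the lanes of the final
// partial vector that would run past the upper bound; with constant bounds
// that divide evenly, a constant all-true mask is produced instead so the
// masked memory operations fold away.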
/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass =
      rewriter.create<arith::ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates an affine expression.
//
// TODO: generalize for sparse tensor subscripts
//
static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
                       AffineExpr a, Location loc) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    return codegen.loops[idx]; // universal dense index
  }
  case AffineExprKind::Add: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::AddIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::MulIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Constant: {
    int64_t c = a.cast<AffineConstantExpr>().getValue();
    return rewriter.create<arith::ConstantIndexOp>(loc, c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
                          linalg::GenericOp op, OpOperand *t,
                          SmallVector<Value, 4> &args) {
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    // Note that currently, all sparse subscripts are simple.
    // TODO: accept affine too?
    AffineExpr a = map.getResult(perm(enc, rank - 1));
    assert(a.getKind() == AffineExprKind::DimId);
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
    }
  }
  return codegen.buffers[tensor];
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           Value rhs) {
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(op.getLoc(), codegen.curVecMask, rhs,
                                      codegen.redVal);
    updateReduc(merger, codegen, rhs);
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getOutputOperand(0);
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(op.getLoc(), rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load =
          rewriter.create<arith::ExtUIOp>(loc, load, genIntType(rewriter, 64));
    load =
        rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}
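
// For illustration only (schematic): when the pointer or index arrays are
// stored with 16-bit entries, the scalar path above loads an i16, zero
// extends it to i64, and casts the result to index before it is used for
// looping, whereas the vector path widens 8/16-bit entries only to i32 and
// widens 32-bit entries to i64 unless the enableSIMDIndex32 option keeps
// them in the cheaper 32-bit gather/scatter form.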
/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv =
        rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<arith::AddIOp>(loc, mul, i);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}

/// Determines if affine expression is invariant.
static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
                              unsigned ldx, bool &atLevel) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (idx == ldx)
      atLevel = true;
    return codegen.loops[idx] != nullptr; // no longer in play?
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
  }
  default:
    return true;
  }
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool atStart,
                          Kind last = Kind::kTensor) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (!isInvariantAffine(codegen, a, ldx, atLevel))
        return; // still in play
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    if (!atLevel)
      return;
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      // Start or end a scalarized reduction
      if (atStart) {
        Value load = genTensorLoad(merger, codegen, rewriter, op, exp);
        codegen.redKind = getReduction(last);
        codegen.redExp = exp;
        updateReduc(merger, codegen, load);
      } else {
        Value redVal = codegen.redVal;
        updateReduc(merger, codegen, Value());
        codegen.redExp = -1u;
        codegen.redKind = kNoReduc;
        genTensorStore(merger, codegen, rewriter, op, redVal);
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      merger.exp(exp).val =
          atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    Kind last = merger.exp(exp).kind;
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  return needsUniv;
}
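
// For illustration only (schematic): for a compressed dimension of tensor t
// at loop index i, the initialization above amounts to
//
//   pidxs[t][i] = pointers[t][i][p]       // start of the nonzero segment
//   highs[t][i] = pointers[t][i][p + 1]   // end of the nonzero segment
//
// where p is the position index of the enclosing dimension (or zero at the
// outermost level), while a dense dimension merely requests that the
// universal index be maintained for the loop sequence.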
/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}
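
// For illustration only: under SparseVectorizationStrategy::
// kAnyStorageInnerLoop every innermost loop, including loops over compressed
// dimensions, is a SIMD candidate, whereas kDenseInnerLoop restricts
// vectorization to innermost loops over dense dimensions; the
// parallelization strategies draw the analogous dense/any-storage
// distinction for loops that are not reductions (and, for the OuterLoop
// variants, only for the outermost loop).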
/// Checks unit stride for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// This prevents effective vectorization.
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        AffineExpr a = map.getResult(d);
        // Report non-unit stride if innermost index appears at an outer
        // dimension (true non-unit stride) or if the innermost index appears
        // in a compound subscript in the innermost dimension. Even if the
        // latter is unit stride, it does not play well with scatter/gather.
        // TODO: accept unit stride affine innermost like a[i,j+k+1]?
        if (a.isFunctionOfDim(idx) &&
            ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId)))
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = isReductionIterator(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step =
      rewriter.create<arith::ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential or vector loop.
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    // In a vector loop, bring reduction into SIMD form, if not already.
    if (isVector && !codegen.redVal.getType().isa<VectorType>()) {
      VectorType vtp = vectorType(codegen, codegen.redVal.getType());
      Value vred = genVectorReducInit(codegen, rewriter, loc, vtp);
      updateReduc(merger, codegen, vred);
    }
    operands.push_back(codegen.redVal);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (codegen.redVal)
    updateReduc(merger, codegen, forOp.getRegionIterArgs().front());
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emits a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (codegen.redVal) {
    types.push_back(codegen.redVal.getType());
    operands.push_back(codegen.redVal);
  }
  if (needsUniv) {
    types.push_back(indexType);
    operands.push_back(codegen.loops[idx]);
  }
  assert(types.size() == operands.size());
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                 op1, op2);
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (codegen.redVal)
    updateReduc(merger, codegen, after->getArgument(o++));
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction,
                              scf::WhileOp whileOp) {
  Location loc = op.getLoc();
  // Finalize each else branch of all if statements.
  if (codegen.redVal) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               rewriter.getInsertionBlock()->getParentOp())) {
      rewriter.create<scf::YieldOp>(loc, codegen.redVal);
      updateReduc(merger, codegen, ifOp.getResult(0));
      rewriter.setInsertionPointAfter(ifOp);
    }
  }
  rewriter.setInsertionPointToEnd(&whileOp.after().front());
  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 op1, op2);
      Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
    }
  }
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, whileOp->getResult(o++));
  }
  if (needsUniv) {
    operands.push_back(
        rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = whileOp->getResult(o++);
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(whileOp);
}

/// Generates the induction structure for a for-loop.
static void genForInduction(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            Operation *loop) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, loop->getResult(o++));
  }
  assert(o == operands.size());
  if (o > 0)
    rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(loop);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  SmallVector<Type, 4> types;
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                op1, op2);
      } else {
        clause = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
    }
  }
  if (codegen.redVal)
    types.push_back(codegen.redVal.getType());
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Generates end of true branch of if-statement within a while-loop.
static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                  linalg::GenericOp op, scf::IfOp ifOp, Value ifInput) {
  if (codegen.redVal) {
    rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal);
    updateReduc(merger, codegen, ifInput);
  }
  rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
}
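
// For illustration only (an editorial sketch in C-like pseudocode of the
// co-iteration pattern assembled by the helpers above): sparse vector
// addition x = a + b becomes a while-loop that advances over both index
// arrays at once,
//
//   while (ia < hia && ib < hib) {
//     i = min(a_index[ia], b_index[ib]);
//     if (a_index[ia] == i && b_index[ib] == i)
//       x[i] = a[ia] + b[ib];
//     else if (a_index[ia] == i)
//       x[i] = a[ia];
//     else
//       x[i] = b[ib];
//     ia += (a_index[ia] == i);
//     ib += (b_index[ib] == i);
//   }
//
// where genWhile produces the loop control, genLocals the index loads and the
// minimum, genIf the branches, and genWhileInduction the increments.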
//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         std::vector<unsigned> &topSort, unsigned exp,
                         unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  assert(!codegen.loops[idx]);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  bool needsUniv =
      genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  if (needsUniv) {
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}

/// Starts a single loop in current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in current sequence. Returns the new value for needsUniv.
static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp);
    return needsUniv;
  }
  // End a for-loop.
  genForInduction(merger, codegen, rewriter, op, loop);
  return false;
}

/// Ends a loop sequence at given level.
static void endLoopSeq(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned exp, unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  codegen.loops[idx] = Value();
  // Bring a pending reduction back from SIMD form when sequence ends.
  if (codegen.redVal)
    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
      updateReduc(merger, codegen,
                  genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for coiteration.
    Value ifInput = codegen.redVal;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, ifInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, idx, ldx);
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should never need this eventually).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
    if (!optExp.hasValue())
      return failure();
    unsigned exp = optExp.getValue();

    // Rejects an inadmissable tensor expression.
    if (!isAdmissableTensorExp(merger, op, exp))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}