//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Declarations of data structures.
//===----------------------------------------------------------------------===//

namespace {

// Iteration graph sorting.
enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

// Reduction kinds.
enum Reduction { kSum, kProduct, kAnd, kOr, kXor };

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
  // TODO: currently only done for (a chain of) innermost for-loops, where it
  // is most effective; we could generalize to more outer and while-loops.
  unsigned redExp;
  Value redVal;
  Reduction redKind;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse compiler analysis methods.
//===----------------------------------------------------------------------===//

/// Helper method to apply dimension ordering permutation.
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

/// Helper method to translate dim level type to internal representation.
static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used in more than one dimension of a tensor. Also rejects
/// affine expressions that are not a direct index for annotated tensors.
/// TODO: accept more affine cases for sparse tensors
static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
                       bool isDense) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (!merger.isDim(tensor, idx, Dim::kUndef))
      return false; // used more than once
    merger.setDim(tensor, idx, dim);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    if (!isDense)
      return false;
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
  }
  case AffineExprKind::Constant:
    return isDense;
  default:
    return false;
  }
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissable. Returns false if
/// no annotations are found or inadmissable constructs occur.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned tensor = t->getOperandNumber();
      AffineExpr a = map.getResult(perm(enc, d));
      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissable affine expression
    }
  }
  return annotated;
}
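
// For illustration only (a hypothetical annotation, not taken from this file):
// a matrix operand declared with
//
//   #CSR = #sparse_tensor.encoding<{
//     dimLevelType = [ "dense", "compressed" ]
//   }>
//
// and accessed through the identity indexing map (i, j) -> (i, j) makes
// findSparseAnnotations record index i as Dim::kDense and index j as
// Dim::kSparse for that tensor.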

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Helper method to add all constraints from the indices in one affine
/// expression before all indices in the other affine expression. For
/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
                               AffineExpr a, AffineExpr b, unsigned fidx) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (b)
      addAffineOrderings(adjM, b, AffineExpr(), idx);
    else
      adjM[fidx][idx] = true;
    break;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
    break;
  }
  default:
    break;
  }
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  unsigned mask) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}
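
// A small worked example (hypothetical, for illustration): a tensor accessed
// as A_ijk with the default row-major ordering contributes the edges i -> j
// and j -> k, so any topological order must satisfy i < j < k. If the same
// tensor instead carried a dimOrdering that permutes (i, j, k) to (j, i, k),
// the derived edges would be j -> i and i -> k instead.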

/// Returns true when the tensor expression is admissable for codegen.
/// Since all sparse input tensors are admissable, we just need to check
/// whether the output tensor in the tensor expression codegen is admissable.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissable since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissable since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissable without special codegen.
  if (merger.isConjunction(tensor, exp))
    return true;
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}
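
// Two illustrative cases (hypothetical kernels, not taken from this file):
// an update such as x(i) = x(i) * 2.0 on a sparse x only touches existing
// nonzeros, so it falls under the "simply dynamic" rule above and is accepted,
// whereas x(i) = a(i) + b(i) with a truly sparse x may have to insert new
// nonzeros and is therefore rejected until workspaces are supported.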

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods.
//===----------------------------------------------------------------------===//

/// Maps reduction kind to name encoding.
static StringRef getReductionName(Reduction kind) {
  switch (kind) {
  case kSum:
    return "add";
  case kProduct:
    return "mul";
  case kAnd:
    return "and";
  case kOr:
    return "or";
  case kXor:
    return "xor";
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps operation to reduction.
static Reduction getReduction(Kind kind) {
  switch (kind) {
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubI:
    return kSum;
  case Kind::kMulF:
  case Kind::kMulI:
    return kProduct;
  case Kind::kAndI:
    return kAnd;
  case Kind::kOrI:
    return kOr;
  case Kind::kXorI:
    return kXor;
  default:
    llvm_unreachable("unexpected reduction operator");
  }
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
static Value genReductionInit(PatternRewriter &rewriter, Location loc,
                              Reduction kind, VectorType vtp, Value r) {
  switch (kind) {
  case kSum:
  case kXor: {
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    Attribute zero = rewriter.getZeroAttr(vtp);
    Value vec = rewriter.create<ConstantOp>(loc, vtp, zero);
    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
  }
  case kProduct: {
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    Type etp = vtp.getElementType();
    Attribute one;
    if (etp.isa<FloatType>())
      one = rewriter.getFloatAttr(etp, 1.0);
    else
      one = rewriter.getIntegerAttr(etp, 1);
    Value vec =
        rewriter.create<ConstantOp>(loc, vtp, DenseElementsAttr::get(vtp, one));
    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
  }
  case kAnd:
  case kOr:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Detects in-place annotation on tensor argument.
static bool getInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
        return attr.getValue();
  return false;
}
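
// For illustration (hypothetical signature, not part of this file): a function
// argument declared along the lines of
//
//   func @kernel(%argx: tensor<32xf32> {linalg.inplaceable = true}, ...)
//
// makes getInPlace() return true for %argx, which lets the output buffer below
// be obtained with a plain buffer cast instead of an alloc plus copy.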

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (getInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel. If the tensor
  // materializes within this method, we need to preserve the zero
  // initialization assumption of all sparse output buffers.
  if (auto init = tensor.getDefiningOp<linalg::InitTensorOp>()) {
    Type tp = denseTp.getElementType();
    Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
    Value zero = rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    rewriter.create<linalg::FillOp>(loc, zero, alloc);
    return alloc;
  }
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static bool genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<arith::ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (shape[p] == MemRefType::kDynamicSize)
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
      if (tensor == op.getNumInputs() && !getInPlace(t->get()))
        return false; // reject output if not in-place
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
  return true;
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<arith::ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass =
      rewriter.create<arith::ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}
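
// Schematically, for vector length 16 the masked-load path above yields IR of
// roughly this shape (types abbreviated, illustration only):
//
//   %mask = vector.create_mask %end : vector<16xi1>
//   %pass = arith.constant dense<0.000000e+00> : vector<16xf32>
//   %val  = vector.maskedload %mem[%i], %mask, %pass
//           : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
//
// and the gather path replaces the masked load with a vector.gather on an
// index vector loaded from the sparse indices array.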

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates an affine expression.
//
// TODO: generalize for sparse tensor subscripts
//
static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
                       AffineExpr a, Location loc) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    return codegen.loops[idx]; // universal dense index
  }
  case AffineExprKind::Add: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::AddIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::MulIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Constant: {
    int64_t c = a.cast<AffineConstantExpr>().getValue();
    return rewriter.create<arith::ConstantIndexOp>(loc, c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
                          linalg::GenericOp op, OpOperand *t,
                          SmallVector<Value, 4> &args) {
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    // Note that currently, all sparse subscripts are simple.
    // TODO: accept affine too?
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
    }
  }
  return codegen.buffers[tensor];
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           Value rhs) {
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(op.getLoc(), codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getOutputOperand(0);
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(op.getLoc(), rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<arith::ExtUIOp>(loc, load,
                                             rewriter.getIntegerType(64));
    load =
        rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv =
        rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<arith::AddIOp>(loc, mul, i);
}

/// Generates start of a reduction.
static Value genReductionStart(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter,
                               linalg::GenericOp op) {
  if (codegen.redVal)
    return codegen.redVal; // chained with previous for-loop
  // Generate vector or scalar start of a reduction.
  unsigned vl = codegen.curVecLength;
  if (vl > 1) {
    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
    assert(!merger.exp(codegen.redExp).val);
    codegen.curVecLength = 1;
    Value load = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    codegen.curVecLength = vl;
    return genReductionInit(rewriter, op.getLoc(), codegen.redKind, vtp, load);
  }
  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
}

/// Generates end of a reduction.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  assert(codegen.curVecLength == 1);
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  // Generate vector or scalar end of a reduction.
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    StringRef name = getReductionName(codegen.redKind);
    StringAttr kind = rewriter.getStringAttr(name);
    red = rewriter.create<vector::ReductionOp>(
        op.getLoc(), vtp.getElementType(), kind, red, ValueRange{});
  }
  genTensorStore(merger, codegen, rewriter, op, red);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}
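
// Schematically, a scalarized sum reduction produced by the helpers above is
// carried through the loop's iteration arguments (illustration only):
//
//   %red = scf.for %i = %lo to %hi step %c1 iter_args(%acc = %init) -> (f32) {
//     %t = ... load and multiply ...
//     %acc.new = arith.addf %acc, %t : f32
//     scf.yield %acc.new : f32
//   }
//
// and, when vectorized, a single vector.reduction at the end folds the vector
// accumulator back into a scalar before the final store.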

/// Determines if affine expression is invariant.
static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
                              unsigned ldx, bool &atLevel) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (idx == ldx)
      atLevel = true;
    return codegen.loops[idx] != nullptr; // no longer in play?
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
  }
  default:
    return true;
  }
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist,
                          Kind last = Kind::kTensor) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (!isInvariantAffine(codegen, a, ldx, atLevel))
        return; // still in play
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      codegen.redExp = hoist ? exp : -1u;
      codegen.redKind = getReduction(last);
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    Kind last = merger.exp(exp).kind;
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist, last);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist, last);
  }
}
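
// A hypothetical example of the hoisting above: in a kernel of the form
// x(i, j) = a(i, j) * s(i), assuming loop order (i, j), the load s[i] depends
// only on the outer index i, so it is hoisted to just before the inner j loop
// and cached in merger.exp(...).val; the inner loop then reuses the cached
// value instead of reloading it on every iteration.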

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit stride for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// This prevents effective vectorization.
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        AffineExpr a = map.getResult(d);
        // Report non-unit stride if innermost index appears at an outer
        // dimension (true non-unit stride) or if the innermost index appears
        // in a compound subscript in the innermost dimension. Even if the
        // latter is unit stride, it does not play well with scatter/gather.
        if (a.isFunctionOfDim(idx) &&
            ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId)))
          return false;
      }
    }
  }
  return true;
}
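
// For example (hypothetical): if the chosen loop order makes index i the
// innermost loop while a dense 2-D tensor is accessed as a(i, j), then i
// subscripts the outer dimension and every step of the inner loop strides by
// a full row, so denseUnitStrides() returns false and vectorization of that
// loop is declined.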

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = isReductionIterator(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step =
      rewriter.create<arith::ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emits a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                 op1, op2);
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}
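
// Schematically, co-iterating two sparse operands over index i produces a
// while-loop of roughly this shape (illustration only, details omitted):
//
//   %r:2 = scf.while (%pA = %a0, %pB = %b0) : (index, index) -> (index, index) {
//     %cA = arith.cmpi ult, %pA, %aEnd : index
//     %cB = arith.cmpi ult, %pB, %bEnd : index
//     %c  = arith.andi %cA, %cB : i1
//     scf.condition(%c) %pA, %pB : index, index
//   } do {
//   ^bb0(%pA: index, %pB: index):
//     ... load indices, compare against the loop index, guarded body ...
//     scf.yield %pA.next, %pB.next : index, index
//   }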

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 op1, op2);
      Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(
        rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}
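
// At the end of each while-loop iteration, the position of every co-iterated
// sparse tensor advances only if its current index matched the loop index,
// schematically (illustration only):
//
//   %hit  = arith.cmpi eq, %idxA, %i : index
//   %pA.1 = arith.addi %pA, %c1 : index
//   %pA.n = select %hit, %pA.1, %pA : index
//
// while the universal index, when present, is simply incremented by one.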

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                op1, op2);
      } else {
        clause = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should never need this eventually).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
    if (!exp.hasValue())
      return failure();

    // Rejects an inadmissable tensor expression.
    if (!isAdmissableTensorExp(merger, op, exp.getValue()))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    if (!genBuffers(merger, codegen, rewriter, op))
      return failure(); // could not bufferize
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
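
// A minimal usage sketch (hypothetical pass body, not part of this file):
//
//   RewritePatternSet patterns(&getContext());
//   populateSparsificationPatterns(patterns, SparsificationOptions());
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//
// where SparsificationOptions() is assumed to provide reasonable defaults for
// the parallelization and vectorization strategies.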