//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Declarations of data structures.
//===----------------------------------------------------------------------===//

namespace {

// Iteration graph sorting.
enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

// Reduction kinds.
enum Reduction { kSum, kProduct, kAnd, kOr, kXor };

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
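  /// For example, in a kernel of the form x(i) += A(i,j) * b(j), the inner
  /// j-loop can accumulate the partial sum in redVal and emit a single store
  /// per row once the reduction ends.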
  // TODO: currently only done for (a chain of) innermost for-loops, where it
  // is most effective; we could generalize to more outer and while-loops.
  unsigned redExp;
  Value redVal;
  Reduction redKind;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse compiler analysis methods.
//===----------------------------------------------------------------------===//

/// Helper method to apply dimension ordering permutation.
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

/// Helper method to translate dim level type to internal representation.
static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used in more than one dimension of a tensor. Also rejects
/// affine expressions that are not a direct index for annotated tensors.
/// TODO: accept more affine cases for sparse tensors
static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
                       bool isDense) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (!merger.isDim(tensor, idx, Dim::kUndef))
      return false; // used more than once
    merger.setDim(tensor, idx, dim);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    if (!isDense)
      return false;
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
  }
  case AffineExprKind::Constant:
    return isDense;
  default:
    return false;
  }
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned tensor = t->getOperandNumber();
      AffineExpr a = map.getResult(perm(enc, d));
      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissible affine expression
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
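/// The visit vector encodes the DFS state of each node: 0 = not yet visited,
/// 1 = on the current DFS path (reaching such a node again means a cycle),
/// 2 = fully processed.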
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Helper method to add all constraints from the indices in one affine
/// expression before all indices in the other affine expression. For
/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
                               AffineExpr a, AffineExpr b, unsigned fidx) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (b)
      addAffineOrderings(adjM, b, AffineExpr(), idx);
    else
      adjM[fidx][idx] = true;
    break;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
    break;
  }
  default:
    break;
  }
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  unsigned mask) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
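  // The DFS appends each index only after all of its successors have been
  // visited (post-order), so reversing the result below yields a valid
  // topological order.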
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the output tensor in the tensor expression codegen is admissible.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen.
  if (merger.isConjunction(tensor, exp))
    return true;
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods.
//===----------------------------------------------------------------------===//

/// Maps reduction kind to name encoding.
static StringRef getReductionName(Reduction kind) {
  switch (kind) {
  case kSum:
    return "add";
  case kProduct:
    return "mul";
  case kAnd:
    return "and";
  case kOr:
    return "or";
  case kXor:
    return "xor";
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps operation to reduction.
static Reduction getReduction(Kind kind) {
  switch (kind) {
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubI:
    return kSum;
  case Kind::kMulF:
  case Kind::kMulI:
    return kProduct;
  case Kind::kAndI:
    return kAnd;
  case Kind::kOrI:
    return kOr;
  case Kind::kXorI:
    return kXor;
  default:
    llvm_unreachable("unexpected reduction operator");
  }
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
static Value genReductionInit(PatternRewriter &rewriter, Location loc,
                              Reduction kind, VectorType vtp, Value r) {
  switch (kind) {
  case kSum:
  case kXor: {
    // Initialize reduction vector to: | 0 | .. | 0 | r |
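    // Zero is the identity element for both addition and xor, so seeding all
    // lanes with zero and placing the incoming scalar r in lane 0 preserves
    // the running reduction value.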
    Attribute zero = rewriter.getZeroAttr(vtp);
    Value vec = rewriter.create<ConstantOp>(loc, vtp, zero);
    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
  }
  case kProduct: {
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    Type etp = vtp.getElementType();
    Attribute one;
    if (etp.isa<FloatType>())
      one = rewriter.getFloatAttr(etp, 1.0);
    else
      one = rewriter.getIntegerAttr(etp, 1);
    Value vec =
        rewriter.create<ConstantOp>(loc, vtp, DenseElementsAttr::get(vtp, one));
    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
  }
  case kAnd:
  case kOr:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Detects in-place annotation on tensor argument.
static bool getInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor simply could materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (getInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel. If the tensor
  // materializes within this method, we need to preserve the zero
  // initialization assumption of all sparse output buffers.
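  // An InitTensorOp materializes an uninitialized tensor, so the freshly
  // allocated buffer is explicitly filled with zeros rather than copied from
  // the (undefined) contents of the outs() tensor.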
  if (auto init = tensor.getDefiningOp<linalg::InitTensorOp>()) {
    Type tp = denseTp.getElementType();
    Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
    Value zero = rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    rewriter.create<linalg::FillOp>(loc, zero, alloc);
    return alloc;
  }
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static bool genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointer and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (shape[p] == MemRefType::kDynamicSize)
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
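      // An annotated sparse output is only accepted when it is marked
      // in-place; otherwise a new sparse tensor would have to be
      // materialized, which this prototype bufferization does not support.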
      if (tensor == op.getNumInputs() && !getInPlace(t->get()))
        return false; // reject output if not in-place
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
  return true;
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
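/// A vector subscript in the last position selects a scatter; otherwise a
/// masked store to consecutive locations is emitted.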
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates an affine expression.
//
// TODO: generalize for sparse tensor subscripts
//
static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
                       AffineExpr a, Location loc) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    return codegen.loops[idx]; // universal dense index
  }
  case AffineExprKind::Add: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<AddIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<MulIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Constant: {
    int64_t c = a.cast<AffineConstantExpr>().getValue();
    return rewriter.create<ConstantIndexOp>(loc, c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
                          linalg::GenericOp op, OpOperand *t,
                          SmallVector<Value, 4> &args) {
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    // Note that currently, all sparse subscripts are simple.
    // TODO: accept affine too?
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
    }
  }
  return codegen.buffers[tensor];
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           Value rhs) {
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(op.getLoc(), codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getOutputOperand(0);
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(op.getLoc(), rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<ZeroExtendIOp>(loc, load,
                                            rewriter.getIntegerType(64));
    load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<AddIOp>(loc, mul, i);
}

/// Generates start of a reduction.
static Value genReductionStart(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter,
                               linalg::GenericOp op) {
  if (codegen.redVal)
    return codegen.redVal; // chained with previous for-loop
  // Generate vector or scalar start of a reduction.
  unsigned vl = codegen.curVecLength;
  if (vl > 1) {
    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
    assert(!merger.exp(codegen.redExp).val);
    codegen.curVecLength = 1;
    Value load = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    codegen.curVecLength = vl;
    return genReductionInit(rewriter, op.getLoc(), codegen.redKind, vtp, load);
  }
  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
}

/// Generates end of a reduction.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  assert(codegen.curVecLength == 1);
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  // Generate vector or scalar end of a reduction.
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    StringRef name = getReductionName(codegen.redKind);
    StringAttr kind = rewriter.getStringAttr(name);
    red = rewriter.create<vector::ReductionOp>(
        op.getLoc(), vtp.getElementType(), kind, red, ValueRange{});
  }
  genTensorStore(merger, codegen, rewriter, op, red);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  if (merger.exp(exp).kind == Kind::kNegI) {
    // TODO: no negi in std, need to make zero explicit.
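    // Rewrite the unary negation -x as the binary 0 - x: the operand moves
    // into the v1 position and a zero constant (broadcast when vectorizing)
    // takes the v0 position.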
    Type tp = op.getOutputTensorTypes()[0].getElementType();
    v1 = v0;
    v0 = rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    if (codegen.curVecLength > 1)
      v0 = genVectorInvariantValue(codegen, rewriter, v0);
  }
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}

/// Determines if affine expression is invariant.
static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
                              unsigned ldx, bool &atLevel) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (idx == ldx)
      atLevel = true;
    return codegen.loops[idx] != nullptr; // no longer in play?
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
  }
  default:
    return true;
  }
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist,
                          Kind last = Kind::kTensor) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (!isInvariantAffine(codegen, a, ldx, atLevel))
        return; // still in play
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      codegen.redExp = hoist ? exp : -1u;
      codegen.redKind = getReduction(last);
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    Kind last = merger.exp(exp).kind;
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist, last);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist, last);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
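        // Locate the innermost enclosing loop for which a position index of
        // this tensor is already available; that position (or zero at the
        // outermost level) selects the segment in the pointers array.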
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit strides for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// For now, we reject vectorization of such cases.
/// TODO: implement strided load/stores on dense arrays
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned ldx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        AffineExpr a = map.getResult(d);
        if (a.getKind() != AffineExprKind::DimId)
          return false; // very conservative
        unsigned idx = a.cast<AffineDimExpr>().getPosition();
        if (idx == ldx && d != rank - 1)
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
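/// The loop is emitted as an scf.parallel, a vectorized scf.for, or a plain
/// scalar scf.for, depending on the requested vectorization/parallelization
/// strategies and on whether the index iterates over sparse or dense storage.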
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = isReductionIterator(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emits a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
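  // Each sparse tensor that participates in the co-iteration contributes its
  // current position index as a loop-carried value; the universal dense index
  // is appended last when it must be maintained.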
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
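  // Load the coordinate at each sparse position and, when no universal dense
  // index is maintained, track the minimum coordinate to drive the
  // co-iteration.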
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}

/// Generates a single if-statement within a while-loop.
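/// The condition tests, for every sparse tensor in the lattice point, whether
/// its current coordinate matches the loop index; dense dimensions contribute
/// a trivially true clause.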
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for coiteration.
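    // Inside a while-loop every dominated lattice point receives its own
    // guarded branch; inside a for-loop the single body is emitted directly.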
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should never need this eventually).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
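/// The rewrite proceeds in stages: analyze the sparsity annotations, compute
/// a feasible (cycle-free) loop order, build the tensor expression in SSA
/// form, and finally generate the buffers, loops, and resulting tensor.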
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
    if (!exp.hasValue())
      return failure();

    // Rejects an inadmissible tensor expression.
    if (!isAdmissableTensorExp(merger, op, exp.getValue()))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    if (!genBuffers(merger, codegen, rewriter, op))
      return failure(); // could not bufferize
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
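
// A minimal usage sketch, assuming a client pass that already has a FuncOp
// "funcOp" in scope and has constructed a SparsificationOptions value
// "sparseOptions" with the desired strategies (these names are illustrative
// only and not part of this file):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   populateSparsificationPatterns(patterns, sparseOptions);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
//
// A driver pass would typically apply the GenericOpSparsifier pattern to each
// function in this way, after which the annotated linalg.generic operations
// have been rewritten into explicit sparse loop nests.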