//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

// Iteration graph sorting.
enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
  // TODO: currently only done for (a chain of) innermost for-loops, where it
  // is most effective; we could generalize to more outer and while-loops.
  unsigned redExp;
  Value redVal;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

// Helper method to apply dimension ordering permutation.
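// As an illustrative (hypothetical) example: for an encoding whose
// dimOrdering is affine_map<(i, j) -> (j, i)>, perm(enc, 0) returns 1 and
// perm(enc, 1) returns 0, mapping each storage dimension back to the logical
// tensor dimension it stores.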
static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

// Helper method to translate dim level type to internal representation.
static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    if (!map.isProjectedPermutation())
      return false;
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      merger.setDim(t->getOperandNumber(), idx, toDim(enc, d));
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  unsigned mask) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
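    // As a sketch of the permuted case (illustrative only): if that tensor
    // carried an explicit dimOrdering of affine_map<(i, j, k) -> (k, j, i)>,
    // the loop below would instead add the constraints k < j and j < i to
    // the adjacency matrix.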
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      unsigned f = map.getDimPosition(perm(enc, d - 1));
      unsigned t = map.getDimPosition(perm(enc, d));
      adjM[f][t] = true;
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true when the tensor expression is admissable for codegen.
/// Since all sparse input tensors are admissable, we just need to check
/// whether the output tensor in the tensor expression codegen is admissable.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissable since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissable since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissable without special codegen.
  if (merger.isConjunction(tensor, exp))
    return true;
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Detects in-place annotation on tensor argument.
static bool getInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Generates buffer for the output tensor.
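/// A rough sketch of the in-place case this relies on (illustrative IR only):
/// when the output tensor is a function argument annotated as
///
///   func @kernel(%argx: tensor<32xf64> {linalg.inplaceable = true}) ...
///
/// the buffer behind that argument can be reused directly instead of
/// allocating and copying a fresh one.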
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (getInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel.
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static bool genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (shape[p] == MemRefType::kDynamicSize)
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
      if (tensor == op.getNumInputs() && !getInPlace(t->get()))
        return false; // reject output if not in-place
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
  return true;
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
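/// Schematically (a sketch, not verbatim output): an indirect subscript,
/// i.e., a vector of indices in the last position, lowers to vector.scatter,
/// while a contiguous subscript lowers to vector.maskedstore.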
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Location loc = op.getLoc();
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           OpOperand *t, Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  OpOperand *lhs = op.getOutputOperand(0);
  if (lhs == t && codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<ZeroExtendIOp>(loc, load,
                                            rewriter.getIntegerType(64));
    load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<AddIOp>(loc, mul, i);
}

/// Generates start of a reduction.
static Value genReductionStart(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter,
                               linalg::GenericOp op) {
  if (codegen.redVal)
    return codegen.redVal; // chained with previous for-loop
  if (codegen.curVecLength > 1) {
    // TODO: assumes + reductions for now
    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
    return rewriter.create<ConstantOp>(op.getLoc(), vtp,
                                       rewriter.getZeroAttr(vtp));
  }
  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
}

/// Generates end of a reduction.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  assert(codegen.curVecLength == 1);
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  OpOperand *lhs = op.getOutputOperand(0);
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    // TODO: assumes + reductions for now
    StringAttr kind = rewriter.getStringAttr("add");
    Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    // Integer reductions don't accept an accumulator.
    if (vtp.getElementType().isa<IntegerType>()) {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ValueRange{});
      red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
    } else {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ld);
    }
  }
  genTensorStore(merger, codegen, rewriter, op, lhs, red);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  if (merger.exp(exp).kind == Kind::kNegI) {
    // TODO: no negi in std, need to make zero explicit.
    Type tp = op.getOutputTensorTypes()[0].getElementType();
    v1 = v0;
    v0 = rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    if (codegen.curVecLength > 1)
      v0 = genVectorInvariantValue(codegen, rewriter, v0);
  }
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
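/// For example (an illustrative case only), in a kernel C(i,j) = A(i,j) * b(i)
/// with loop order i, j, the load b[i] depends on i alone, so once the i-loop
/// has been entered the load can be hoisted above the inner j-loop.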
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (!codegen.loops[idx])
        return; // still in play
      else if (idx == ldx)
        atLevel = true;
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      codegen.redExp = hoist ? exp : -1u;
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit strides for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// For now, we reject vectorization of such cases.
/// TODO: implement strided load/stores on dense arrays
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        if (map.getDimPosition(d) == idx && d != rank - 1)
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = isReductionIterator(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emits a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
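/// Roughly (a sketch of the intent, not verbatim output), each sparse operand
/// yields back
///
///   %pidx.next = select (%index == %univ), %pidx + 1, %pidx
///
/// so that only the tensors whose current index matched the merged loop index
/// advance their position, while the optional universal index advances by one.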
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    OpOperand *lhs = op.getOutputOperand(0);
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for coiteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should never need this eventually).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
    if (!exp.hasValue())
      return failure();

    // Rejects an inadmissable tensor expression.
    if (!isAdmissableTensorExp(merger, op, exp.getValue()))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    if (!genBuffers(merger, codegen, rewriter, op))
      return failure(); // could not bufferize
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
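
// A minimal usage sketch (assuming a pass that owns a RewritePatternSet and
// drives it with the standard greedy rewriter; the surrounding names are
// illustrative, not taken from this file):
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, SparsificationOptions());
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));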