//===- Sparsification.cpp - Implementation of sparsification -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
  // TODO: currently only done for (a chain of) innermost for-loops, where it
  // is most effective; we could generalize to more outer and while-loops.
  unsigned redExp;
  Value redVal;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace
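
// As an illustrative sketch (not taken from an actual test), a kernel this
// pass may sparsify is a matrix-vector product in which the matrix operand
// carries a sparse tensor encoding, roughly:
//
//   #CSR = #sparse_tensor.encoding<{
//     dimLevelType = [ "dense", "compressed" ]
//   }>
//
//   %0 = linalg.generic #matvec_trait
//     ins(%A, %b : tensor<?x?xf64, #CSR>, tensor<?xf64>)
//     outs(%x : tensor<?xf64>) {
//     ^bb0(%a: f64, %bj: f64, %xi: f64):
//       %m = mulf %a, %bj : f64
//       %r = addf %xi, %m : f64
//       linalg.yield %r : f64
//   } -> tensor<?xf64>
//
// where #matvec_trait (assumed defined elsewhere) supplies the indexing maps
// and iterator types. The rewriting below replaces such an operation with
// loops over the sparse storage scheme of %A.
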
// Helper method to apply dimension ordering permutation.
static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

// Helper method to translate dim level type to internal representation.
static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    if (!map.isProjectedPermutation())
      return false;
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      merger.setDim(t->getOperandNumber(), idx, toDim(enc, d));
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  bool sparseOnly) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when sparse only is requested.
    if (sparseOnly && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
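    // As an illustrative sketch of how an explicit ordering changes this
    // (not a case from the original comments): accessing A_ij through an
    // encoding with dimOrdering = affine_map<(i, j) -> (j, i)> stores the
    // j-dimension first, so the edge added below becomes j -> i (j outer,
    // i inner) rather than i -> j.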
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      unsigned f = map.getDimPosition(perm(enc, d - 1));
      unsigned t = map.getDimPosition(perm(enc, d));
      adjM[f][t] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the output tensor in the tensor expression codegen is admissible.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen.
  if (merger.isConjunction(tensor, exp))
    return true;
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Detects in-place annotation on tensor argument.
static bool getInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Generates buffer for the output tensor.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
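  // A minimal sketch of such an annotation (a hypothetical kernel signature,
  // using the attribute name behind kInplaceableAttrName) would be:
  //   func @kernel(%argb: tensor<?x?xf64, #CSR>,
  //                %argx: tensor<?xf64> {linalg.inplaceable = true})
  //       -> tensor<?xf64>
  // in which case the buffer of %argx is updated directly.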
  if (getInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel.
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static bool genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointer and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find lower and upper bound in current dimension. Note that a
      // permuted encoding queries static type dimensions accordingly,
      // but queries dynamic type dimensions in the generated order.
      Value up;
      unsigned p = perm(enc, d);
      if (shape[p] == MemRefType::kDynamicSize) {
        up = rewriter.create<tensor::DimOp>(loc, t->get(), d);
        args.push_back(up);
      } else {
        up = rewriter.create<ConstantIndexOp>(loc, shape[p]);
      }
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
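    // As an illustration of the storage scheme accessed above (a sketch,
    // not tied to a particular test): the 2x4 CSR matrix
    //   [ [ 1, 0, 0, 2 ],
    //     [ 0, 0, 3, 0 ] ]
    // with dimension 1 compressed is stored as
    //   pointers[1] = [ 0, 2, 3 ]
    //   indices[1]  = [ 0, 3, 2 ]
    //   values      = [ 1, 2, 3 ]
    // where the pointers bound the per-row ranges into indices/values.
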
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
      if (tensor == op.getNumInputs() && !getInPlace(t->get()))
        return false; // reject output if not in-place
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
  return true;
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen,
                    ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  Value end = rewriter.create<SubIOp>(loc, hi, iv);
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Location loc = op.getLoc();
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           OpOperand *t, Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  OpOperand *lhs = op.getOutputOperand(0);
  if (lhs == t && codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
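  // As an illustration of the chaining above: for a row-sum x(i) = SUM_j
  // a(i, j), the innermost for-loop over j carries the partial sum as an
  // iter-arg, the "store" merely updates that chained value, and the actual
  // memory store is emitted once by genReductionEnd after the loop.
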
  // Actual store.
  SmallVector<Value, 4> args;
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential
    // for incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that are already 64-bit cannot, in theory, express the
  // full range, as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<ZeroExtendIOp>(loc, load,
                                            rewriter.getIntegerType(64));
    load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<AddIOp>(loc, mul, i);
}

/// Generates start of a reduction.
static Value genReductionStart(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter,
                               linalg::GenericOp op) {
  if (codegen.redVal)
    return codegen.redVal; // chained with previous for-loop
  if (codegen.curVecLength > 1) {
    // TODO: assumes + reductions for now
    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
    return rewriter.create<ConstantOp>(op.getLoc(), vtp,
                                       rewriter.getZeroAttr(vtp));
  }
  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
}

/// Generates end of a reduction.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  assert(codegen.curVecLength == 1);
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  OpOperand *lhs = op.getOutputOperand(0);
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    // TODO: assumes + reductions for now
    StringAttr kind = rewriter.getStringAttr("add");
    Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    // Integer reductions don't accept an accumulator.
    if (vtp.getElementType().isa<IntegerType>()) {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ValueRange{});
      red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
    } else {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ld);
    }
  }
  genTensorStore(merger, codegen, rewriter, op, lhs, red);
}
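
// As an illustration of the wrap-up above (a sketch): a vectorized sum
// reduction x += SUM_i a(i) carries a vector of partial sums through the
// loop; here vector.reduction folds that vector into a scalar, and for
// integer element types the original value of x is added back separately,
// since the reduction does not accept an accumulator operand in that case.
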
/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  if (merger.exp(exp).kind == Kind::kNegI) {
    // TODO: no negi in std, need to make zero explicit.
    Type tp = op.getOutputTensorTypes()[0].getElementType();
    v1 = v0;
    v0 = rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    if (codegen.curVecLength > 1)
      v0 = genVectorInvariantValue(codegen, rewriter, v0);
  }
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (!codegen.loops[idx])
        return; // still in play
      else if (idx == ldx)
        atLevel = true;
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      codegen.redExp = hoist ? exp : -1u;
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
  return needsUniv;
}
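
// As an illustration of the initialization above: for a CSR matrix accessed
// as a(i, j), once the position of row i is known, the inner sparse range for
// j is simply pointers[i] .. pointers[i + 1], loaded here into pidxs/highs.
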
/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit strides for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// For now, we reject vectorization of such cases.
/// TODO: implement strided load/stores on dense arrays
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        if (map.getDimPosition(d) == idx && d != rank - 1)
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emit a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}
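
// As an illustrative sketch of the co-iteration emitted above (and completed
// by genWhileInduction below), two sparse vectors a(i) and b(i) roughly yield
//
//   while (p_a < hi_a && p_b < hi_b) {
//     i_a = indices_a[p_a];
//     i_b = indices_b[p_b];
//     i = min(i_a, i_b);  // or the universal index
//     ... body, guarded by i_a == i and/or i_b == i ...
//     p_a += (i_a == i) ? 1 : 0;
//     p_b += (i_b == i) ? 1 : 0;
//   }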

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}
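
// As an illustration of the dense position initialization above: a dense
// m x n output accessed as x(i, j) yields the chained addresses i at the
// outer level and n * i + j at the inner level, i.e. the usual row-major
// linearized offset computed by genAddress.
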
/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    OpOperand *lhs = op.getOutputOperand(0);
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
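  // As an illustration: for x(i) = a(i) + b(i) with both a and b sparse, the
  // lattice for index i keeps the points {a,b}, {a}, and {b} (a union),
  // whereas for x(i) = a(i) * b(i) only the intersection point {a,b} survives
  // optimization, so a single co-iteration loop suffices.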
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should never need this eventually).
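    // For example, for a CSR output the argument list built below becomes
    // (pointers[1], indices[1], values); in general it contains, per sparse
    // dimension in storage order, the pointers and indices arrays, followed
    // by the values array.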
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
    if (!exp.hasValue())
      return failure();

    // Rejects an inadmissible tensor expression.
    if (!isAdmissableTensorExp(merger, op, exp.getValue()))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    if (!genBuffers(merger, codegen, rewriter, op))
      return failure(); // could not bufferize
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
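
// As an illustrative usage sketch (a hypothetical pass body, not part of this
// file; options and driver call shown with assumed defaults), the populated
// patterns are typically applied greedily:
//
//   void runOnFunction() {
//     RewritePatternSet patterns(&getContext());
//     populateSparsificationPatterns(patterns, SparsificationOptions());
//     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
//   }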