//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
  // TODO: currently only done for (a chain of) innermost for-loops, where it
  // is most effective; we could generalize to more outer and while-loops.
  unsigned redExp;
  Value redVal;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

// Helper method to apply dimension ordering permutation.
static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

// Helper method to translate dim level type to internal representation.
static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    if (!map.isProjectedPermutation())
      return false;
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      merger.setDim(t->getOperandNumber(), idx, toDim(enc, d));
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  bool sparseOnly) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when sparse only is requested.
    if (sparseOnly && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      unsigned f = map.getDimPosition(perm(enc, d - 1));
      unsigned t = map.getDimPosition(perm(enc, d));
      adjM[f][t] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the output tensor in the tensor expression codegen is admissible.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen.
  if (merger.isConjunction(tensor, exp))
    return true;
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Detects in-place annotation on tensor argument.
static bool getInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Generates buffer for the output tensor.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (getInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel.
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static bool genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointer and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (shape[p] == MemRefType::kDynamicSize)
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
      if (tensor == op.getNumInputs() && !getInPlace(t->get()))
        return false; // reject output if not in-place
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
  return true;
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Location loc = op.getLoc();
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           OpOperand *t, Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  OpOperand *lhs = op.getOutputOperand(0);
  if (lhs == t && codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<ZeroExtendIOp>(loc, load,
                                            rewriter.getIntegerType(64));
    load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<AddIOp>(loc, mul, i);
}

/// Generates start of a reduction.
static Value genReductionStart(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter,
                               linalg::GenericOp op) {
  if (codegen.redVal)
    return codegen.redVal; // chained with previous for-loop
  if (codegen.curVecLength > 1) {
    // TODO: assumes + reductions for now
    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
    return rewriter.create<ConstantOp>(op.getLoc(), vtp,
                                       rewriter.getZeroAttr(vtp));
  }
  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
}

/// Generates end of a reduction.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  assert(codegen.curVecLength == 1);
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  OpOperand *lhs = op.getOutputOperand(0);
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    // TODO: assumes + reductions for now
    StringAttr kind = rewriter.getStringAttr("add");
    Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    // Integer reductions don't accept an accumulator.
    if (vtp.getElementType().isa<IntegerType>()) {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ValueRange{});
      red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
    } else {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ld);
    }
  }
  genTensorStore(merger, codegen, rewriter, op, lhs, red);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  if (merger.exp(exp).kind == Kind::kNegI) {
    // TODO: no negi in std, need to make zero explicit.
    Type tp = op.getOutputTensorTypes()[0].getElementType();
    v1 = v0;
    v0 = rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    if (codegen.curVecLength > 1)
      v0 = genVectorInvariantValue(codegen, rewriter, v0);
  }
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (!codegen.loops[idx])
        return; // still in play
      else if (idx == ldx)
        atLevel = true;
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      codegen.redExp = hoist ? exp : -1u;
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit strides for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// For now, we reject vectorization of such cases.
/// TODO: implement strided load/stores on dense arrays
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        if (map.getDimPosition(d) == idx && d != rank - 1)
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = isReductionIterator(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emits a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    OpOperand *lhs = op.getOutputOperand(0);
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should never need this eventually).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
    if (!exp.hasValue())
      return failure();

    // Rejects an inadmissible tensor expression.
    if (!isAdmissableTensorExp(merger, op, exp.getValue()))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    if (!genBuffers(merger, codegen, rewriter, op))
      return failure(); // could not bufferize
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
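
// A minimal usage sketch (not part of this file's logic): a function pass that
// owns a RewritePatternSet could drive the rewriting rule above roughly as
// follows. The pass struct and the exact driver entry point shown here are
// illustrative assumptions, not something this file defines.
//
//   struct MySparsificationPass
//       : public PassWrapper<MySparsificationPass, FunctionPass> {
//     void runOnFunction() override {
//       RewritePatternSet patterns(&getContext());
//       populateSparsificationPatterns(patterns, SparsificationOptions());
//       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
//     }
//   };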