//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

// Iteration graph sorting.
enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
  // TODO: currently only done for (a chain of) innermost for-loops, where it
  // is most effective; we could generalize to more outer and while-loops.
  unsigned redExp;
  Value redVal;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

// Helper method to apply dimension ordering permutation.
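// For example, with dimOrdering affine_map<(i, j) -> (j, i)>, perm(enc, 0)
// yields 1 and perm(enc, 1) yields 0, so that callers visit the map results
// in the order in which the dimensions are actually stored.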
static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

// Helper method to translate dim level type to internal representation.
static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used in more than one dimension of a tensor. Also rejects
/// affine expressions that are not a direct index for annotated tensors.
/// TODO: accept more affine cases for sparse tensors
static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
                       bool isDense) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (!merger.isDim(tensor, idx, Dim::kUndef))
      return false; // used more than once
    merger.setDim(tensor, idx, dim);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    if (!isDense)
      return false;
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
  }
  case AffineExprKind::Constant:
    return isDense;
  default:
    return false;
  }
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissable. Returns false if
/// no annotations are found or inadmissable constructs occur.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned tensor = t->getOperandNumber();
      AffineExpr a = map.getResult(perm(enc, d));
      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissable affine expression
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Helper method to add all constraints from the indices in one affine
/// expression before all indices in the other affine expression.
/// For example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
                               AffineExpr a, AffineExpr b, unsigned fidx) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (b)
      addAffineOrderings(adjM, b, AffineExpr(), idx);
    else
      adjM[fidx][idx] = true;
    break;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
    break;
  }
  default:
    break;
  }
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  unsigned mask) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true when the tensor expression is admissable for codegen.
/// Since all sparse input tensors are admissable, we just need to check
/// whether the output tensor in the tensor expression codegen is admissable.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissable since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissable since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissable without special codegen.
  if (merger.isConjunction(tensor, exp))
    return true;
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Detects in-place annotation on tensor argument.
static bool getInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Generates buffer for the output tensor.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (getInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel.
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static bool genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointer and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (shape[p] == MemRefType::kDynamicSize)
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
      if (tensor == op.getNumInputs() && !getInPlace(t->get()))
        return false; // reject output if not in-place
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
  return true;
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
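/// For example, for vector length 16 and the loop "for i = 0, 128, 16" the
/// trip count divides evenly and a constant all-true mask is generated;
/// otherwise the mask enables only the min(step, hi - iv) valid lanes.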
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates an affine expression.
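/// For example, the dense subscript expression i + 2 * j is lowered to
/// explicit index arithmetic over the current values of the loop indices.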
//
// TODO: generalize for sparse tensor subscripts
//
static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
                       AffineExpr a, Location loc) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    return codegen.loops[idx]; // universal dense index
  }
  case AffineExprKind::Add: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<AddIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<MulIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Constant: {
    int64_t c = a.cast<AffineConstantExpr>().getValue();
    return rewriter.create<ConstantIndexOp>(loc, c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
                          linalg::GenericOp op, OpOperand *t,
                          SmallVector<Value, 4> &args) {
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    // Note that currently, all sparse subscripts are simple.
    // TODO: accept affine too?
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
    }
  }
  return codegen.buffers[tensor];
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           Value rhs) {
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(op.getLoc(), codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getOutputOperand(0);
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(op.getLoc(), rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<ZeroExtendIOp>(loc, load,
                                            rewriter.getIntegerType(64));
    load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
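/// This linearizes a dense dimension: p is the position computed for the
/// outer dimensions and i is the current loop index (see the dense position
/// initialization in genLocals).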
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<AddIOp>(loc, mul, i);
}

/// Generates start of a reduction.
static Value genReductionStart(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter,
                               linalg::GenericOp op) {
  if (codegen.redVal)
    return codegen.redVal; // chained with previous for-loop
  if (codegen.curVecLength > 1) {
    // TODO: assumes + reductions for now
    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
    return rewriter.create<ConstantOp>(op.getLoc(), vtp,
                                       rewriter.getZeroAttr(vtp));
  }
  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
}

/// Generates end of a reduction.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  assert(codegen.curVecLength == 1);
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    // TODO: assumes + reductions for now
    StringAttr kind = rewriter.getStringAttr("add");
    Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    // Integer reductions don't accept an accumulator.
    if (vtp.getElementType().isa<IntegerType>()) {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ValueRange{});
      red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
    } else {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ld);
    }
  }
  genTensorStore(merger, codegen, rewriter, op, red);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  if (merger.exp(exp).kind == Kind::kNegI) {
    // TODO: no negi in std, need to make zero explicit.
    Type tp = op.getOutputTensorTypes()[0].getElementType();
    v1 = v0;
    v0 = rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    if (codegen.curVecLength > 1)
      v0 = genVectorInvariantValue(codegen, rewriter, v0);
  }
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}

/// Determines if affine expression is invariant.
static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
                              unsigned ldx, bool &atLevel) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (idx == ldx)
      atLevel = true;
    return codegen.loops[idx] != nullptr; // no longer in play?
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
  }
  default:
    return true;
  }
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (!isInvariantAffine(codegen, a, ldx, atLevel))
        return; // still in play
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      codegen.redExp = hoist ? exp : -1u;
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate.
/// Whether it is actually converted to SIMD code depends on the requested
/// strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit strides for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// For now, we reject vectorization of such cases.
/// TODO: implement strided load/stores on dense arrays
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned ldx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        AffineExpr a = map.getResult(d);
        if (a.getKind() != AffineExprKind::DimId)
          return false; // very conservative
        unsigned idx = a.cast<AffineDimExpr>().getPosition();
        if (idx == ldx && d != rank - 1)
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = isReductionIterator(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
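  // For a sparse index, the loop runs over the positions [lo, hi) obtained
  // from the pointers array; for a dense index, over the dimension size.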
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emits a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
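/// Each sparse position is incremented by one only if its index matched the
/// current loop index; the universal index, if present, always advances.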
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction,
                              ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should never need this eventually).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
    if (!exp.hasValue())
      return failure();

    // Rejects an inadmissable tensor expression.
    if (!isAdmissableTensorExp(merger, op, exp.getValue()))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    if (!genBuffers(merger, codegen, rewriter, op))
      return failure(); // could not bufferize
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}