//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Declarations of data structures.
//===----------------------------------------------------------------------===//

namespace {

// Iteration graph sorting.
enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

// Reduction kinds.
enum Reduction { kSum, kProduct, kAnd, kOr, kXor };

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
  // TODO: currently only done for (a chain of) innermost for-loops, where it
  // is most effective; we could generalize to more outer and while-loops.
  unsigned redExp;
  Value redVal;
  Reduction redKind;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse compiler analysis methods.
//===----------------------------------------------------------------------===//

/// Helper method to apply dimension ordering permutation.
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

/// Helper method to translate dim level type to internal representation.
static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects affine expressions
/// that are not a direct index for annotated tensors.
// TODO: accept more affine cases for sparse tensors
static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
                       bool isDense) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (!merger.isDim(tensor, idx, Dim::kUndef))
      return false; // used more than once
    merger.setDim(tensor, idx, dim);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    if (!isDense)
      return false;
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
  }
  case AffineExprKind::Constant:
    return isDense;
  default:
    return false;
  }
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned tensor = t->getOperandNumber();
      AffineExpr a = map.getResult(perm(enc, d));
      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissible affine expression
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
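/// As a rough usage sketch (mirroring the call site in computeIterationGraph
/// below), the caller seeds an all-zero visit vector, launches a DFS from
/// every unvisited node, and reverses the resulting post-order:
///
///   std::vector<unsigned> visit(n, 0);
///   for (unsigned i = 0; i < n; i++)
///     if (visit[i] == 0)
///       if (!topSortDFS(i, visit, topSort, adjM))
///         return false; // cycle!
///   std::reverse(std::begin(topSort), std::end(topSort));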
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Helper method to add all constraints from the indices in one affine
/// expression before all indices in the other affine expression. For
/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
                               AffineExpr a, AffineExpr b, unsigned fidx) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (b)
      addAffineOrderings(adjM, b, AffineExpr(), idx);
    else
      adjM[fidx][idx] = true;
    break;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
    break;
  }
  default:
    break;
  }
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  unsigned mask) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if tensor has an in-place annotation.
static bool isInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Returns true if tensor materializes into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<InitOp>();
}

/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the output tensor in the tensor expression codegen is admissible.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen, provided
  // the tensor's underlying sparse storage scheme can be modified in place.
  if (merger.isConjunction(tensor, exp))
    return isInPlace(lhs->get());
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Maps reduction kind to name encoding.
static StringRef getReductionName(Reduction kind) {
  switch (kind) {
  case kSum:
    return "add";
  case kProduct:
    return "mul";
  case kAnd:
    return "and";
  case kOr:
    return "or";
  case kXor:
    return "xor";
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps operation to reduction.
static Reduction getReduction(Kind kind) {
  switch (kind) {
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubI:
    return kSum;
  case Kind::kMulF:
  case Kind::kMulI:
    return kProduct;
  case Kind::kAndI:
    return kAnd;
  case Kind::kOrI:
    return kOr;
  case Kind::kXorI:
    return kXor;
  default:
    llvm_unreachable("unexpected reduction operator");
  }
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
static Value genReductionInit(PatternRewriter &rewriter, Location loc,
                              Reduction kind, VectorType vtp, Value r) {
  switch (kind) {
  case kSum:
  case kXor: {
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    Attribute zero = rewriter.getZeroAttr(vtp);
    Value vec = rewriter.create<arith::ConstantOp>(loc, vtp, zero);
    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
  }
  case kProduct: {
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    Type etp = vtp.getElementType();
    Attribute one;
    if (etp.isa<FloatType>())
      one = rewriter.getFloatAttr(etp, 1.0);
    else
      one = rewriter.getIntegerAttr(etp, 1);
    Value vec = rewriter.create<arith::ConstantOp>(
        loc, vtp, DenseElementsAttr::get(vtp, one));
    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
  }
  case kAnd:
  case kOr:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (isInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel. If the tensor
  // materializes into the computation, we need to preserve the zero
  // initialization assumption of all sparse output buffers.
  if (isMaterializing(tensor)) {
    Type tp = denseTp.getElementType();
    Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    rewriter.create<linalg::FillOp>(loc, zero, alloc);
    return alloc;
  }
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<arith::ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (shape[p] == MemRefType::kDynamicSize)
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<arith::ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass =
      rewriter.create<arith::ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
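/// A minimal call sketch, mirroring the dispatch in genTensorStore below
/// (the buffer and subscript arguments are produced by genSubscript):
///
///   if (codegen.curVecLength > 1)
///     genVectorStore(codegen, rewriter, rhs, ptr, args);
///   else
///     rewriter.create<memref::StoreOp>(op.getLoc(), rhs, ptr, args);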
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates an affine expression.
//
// TODO: generalize for sparse tensor subscripts
//
static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
                       AffineExpr a, Location loc) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    return codegen.loops[idx]; // universal dense index
  }
  case AffineExprKind::Add: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::AddIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::MulIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Constant: {
    int64_t c = a.cast<AffineConstantExpr>().getValue();
    return rewriter.create<arith::ConstantIndexOp>(loc, c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
                          linalg::GenericOp op, OpOperand *t,
                          SmallVector<Value, 4> &args) {
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    // Note that currently, all sparse subscripts are simple.
    // TODO: accept affine too?
    AffineExpr a = map.getResult(perm(enc, rank - 1));
    assert(a.getKind() == AffineExprKind::DimId);
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
    }
  }
  return codegen.buffers[tensor];
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           Value rhs) {
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(op.getLoc(), codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getOutputOperand(0);
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(op.getLoc(), rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively defines an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<arith::ExtUIOp>(loc, load,
                                             rewriter.getIntegerType(64));
    load =
        rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv =
        rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<arith::AddIOp>(loc, mul, i);
}

/// Generates start of a reduction.
static Value genReductionStart(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter,
                               linalg::GenericOp op) {
  if (codegen.redVal)
    return codegen.redVal; // chained with previous for-loop
  // Generate vector or scalar start of a reduction.
  unsigned vl = codegen.curVecLength;
  if (vl > 1) {
    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
    assert(!merger.exp(codegen.redExp).val);
    codegen.curVecLength = 1;
    Value load = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    codegen.curVecLength = vl;
    return genReductionInit(rewriter, op.getLoc(), codegen.redKind, vtp, load);
  }
  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
}

/// Generates end of a reduction.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  assert(codegen.curVecLength == 1);
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  // Generate vector or scalar end of a reduction.
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    StringRef name = getReductionName(codegen.redKind);
    StringAttr kind = rewriter.getStringAttr(name);
    red = rewriter.create<vector::ReductionOp>(
        op.getLoc(), vtp.getElementType(), kind, red, ValueRange{});
  }
  genTensorStore(merger, codegen, rewriter, op, red);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}

/// Determines if affine expression is invariant.
static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
                              unsigned ldx, bool &atLevel) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (idx == ldx)
      atLevel = true;
    return codegen.loops[idx] != nullptr; // no longer in play?
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
  }
  default:
    return true;
  }
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist,
                          Kind last = Kind::kTensor) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (!isInvariantAffine(codegen, a, ldx, atLevel))
        return; // still in play
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      codegen.redExp = hoist ? exp : -1u;
      codegen.redKind = getReduction(last);
      assert(!codegen.redVal);
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    Kind last = merger.exp(exp).kind;
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist, last);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist, last);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit stride for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// This prevents effective vectorization.
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        AffineExpr a = map.getResult(d);
        // Report non-unit stride if innermost index appears at an outer
        // dimension (true non-unit stride) or if the innermost index appears
        // in a compound subscript in the innermost dimension. Even if the
        // latter is unit stride, it does not play well with scatter/gather.
        // TODO: accept unit stride affine innermost like a[i,j+k+1]?
        if (a.isFunctionOfDim(idx) &&
            ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId)))
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
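/// A rough call sketch (as done in genLoop below), where the single enabled
/// bit in `indices` selects the tensor that drives the loop:
///
///   bool isOuter = at == 0;
///   bool isInner = at == topSort.size() - 1;
///   genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, indices);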
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = isReductionIterator(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step =
      rewriter.create<arith::ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emit a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
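  // The loop carries one index-typed operand per sparse tensor participating
  // in the co-iteration, plus (optionally) one for the universal index; the
  // "before" region built below tests the exit condition and the "after"
  // region receives the updated positions.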
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                 op1, op2);
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
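  // Each sparse tensor loads the index at its current position; when no
  // universal index is maintained, the running minimum of these loaded
  // indices (computed below) becomes the loop index for this level.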
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 op1, op2);
      Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(
        rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}

/// Generates a single if-statement within a while-loop.
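/// Conceptually, for a lattice point that co-iterates sparse tensors A and B
/// at loop index i, the generated guard looks like (dense tensors contribute
/// a constant true clause; the names a_index and b_index are illustrative):
///
///   if (a_index == i && b_index == i) { ... } else { ... }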
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                op1, op2);
      } else {
        clause = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         std::vector<unsigned> &topSort, unsigned exp,
                         unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}

/// Starts a single loop in current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in current sequence. Returns the new value for needsUniv.
static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    rewriter.setInsertionPointToEnd(&whileOp.after().front());
    genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp.results());
    return needsUniv;
  }
  // End a for-loop.
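  // A scalarized reduction is chained through the for-loop: yield the current
  // reduction value and continue the chain with the loop's result.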
  if (codegen.redVal) {
    rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal);
    codegen.redVal = loop->getResult(0);
  }
  return false;
}

/// Ends a loop sequence at given level.
static void endLoopSeq(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned exp, unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  // Finalize any pending reduction.
  genReductionEnd(merger, codegen, rewriter, op);
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
    rewriter.setInsertionPointAfter(loop);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, idx, ldx);
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should never need this eventually).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
    if (!exp.hasValue())
      return failure();

    // Rejects an inadmissible tensor expression.
    if (!isAdmissableTensorExp(merger, op, exp.getValue()))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}