//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Declarations of data structures.
//===----------------------------------------------------------------------===//

namespace {

// Iteration graph sorting.
enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

// Reduction kinds.
enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        redKind(kNoReduc), curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation.
  /// When indices of a reduction are exhausted, all inner loops can use a
  /// scalarized reduction.
  unsigned redExp;
  Value redVal;
  Reduction redKind;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse compiler analysis methods.
//===----------------------------------------------------------------------===//

/// Helper method to apply dimension ordering permutation.
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

/// Helper method to translate dim level type to internal representation.
static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects affine expressions
/// that are not a direct index for annotated tensors.
// TODO: accept more affine cases for sparse tensors
static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
                       bool isDense) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (!merger.isDim(tensor, idx, Dim::kUndef))
      return false; // used more than once
    merger.setDim(tensor, idx, dim);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    if (!isDense)
      return false;
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
  }
  case AffineExprKind::Constant:
    return isDense;
  default:
    return false;
  }
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned tensor = t->getOperandNumber();
      AffineExpr a = map.getResult(perm(enc, d));
      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissible affine expression
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
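/// The visit array encodes the DFS state per node: 0 = not yet visited,
/// 1 = currently on the DFS stack (so revisiting such a node means a cycle),
/// 2 = fully processed. Nodes are appended to topSort in post-order, which
/// the caller reverses to obtain the final topological order.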
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Helper method to add all constraints from the indices in one affine
/// expression before all indices in the other affine expression. For
/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
                               AffineExpr a, AffineExpr b, unsigned fidx) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (b)
      addAffineOrderings(adjM, b, AffineExpr(), idx);
    else
      adjM[fidx][idx] = true;
    break;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
    break;
  }
  default:
    break;
  }
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  unsigned mask) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if tensor has an in-place annotation.
static bool isInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(),
              linalg::comprehensive_bufferize::BufferizableOpInterface::
                  kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Returns true if tensor materializes into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<InitOp>();
}

/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the output tensor in the tensor expression codegen is admissible.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen, provided
  // the tensor's underlying sparse storage scheme can be modified in place.
  if (merger.isConjunction(tensor, exp))
    return isInPlace(lhs->get());
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (reductions).
//===----------------------------------------------------------------------===//

/// Maps reduction kind to name encoding.
static StringRef getReductionName(Reduction kind) {
  switch (kind) {
  case kNoReduc:
    break;
  case kSum:
    return "add";
  case kProduct:
    return "mul";
  case kAnd:
    return "and";
  case kOr:
    return "or";
  case kXor:
    return "xor";
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps operation to reduction.
static Reduction getReduction(Kind kind) {
  switch (kind) {
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubI:
    return kSum;
  case Kind::kMulF:
  case Kind::kMulI:
    return kProduct;
  case Kind::kAndI:
    return kAnd;
  case Kind::kOrI:
    return kOr;
  case Kind::kXorI:
    return kXor;
  default:
    llvm_unreachable("unexpected reduction operator");
  }
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
                                Location loc, VectorType vtp) {
  Value r = codegen.redVal;
  switch (codegen.redKind) {
  case kNoReduc:
    break;
  case kSum:
  case kXor: {
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    Attribute zero = rewriter.getZeroAttr(vtp);
    Value vec = rewriter.create<arith::ConstantOp>(loc, vtp, zero);
    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
  }
  case kProduct: {
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    Type etp = vtp.getElementType();
    Attribute one;
    if (etp.isa<FloatType>())
      one = rewriter.getFloatAttr(etp, 1.0);
    else
      one = rewriter.getIntegerAttr(etp, 1);
    Value vec = rewriter.create<arith::ConstantOp>(
        loc, vtp, DenseElementsAttr::get(vtp, one));
    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
  }
  case kAnd:
  case kOr:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
}

/// Generates final value for a vector reduction.
static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
                               Location loc, VectorType vtp) {
  StringRef name = getReductionName(codegen.redKind);
  StringAttr kind = rewriter.getStringAttr(name);
  return rewriter.create<vector::ReductionOp>(loc, vtp.getElementType(), kind,
                                              codegen.redVal, ValueRange{});
}

/// Updates scalarized reduction value.
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
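///
/// Illustrative sketch of the three cases handled below (not verbatim IR):
///   in-place output      : reuse the buffer of the outs() tensor directly
///   materializing output : %buf = memref.alloc(...) ; linalg.fill(%zero, %buf)
///   otherwise            : %buf = memref.alloc(...) ; memref.copy(%init, %buf)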
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (isInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel. If the tensor
  // materializes into the computation, we need to preserve the zero
  // initialization assumption of all sparse output buffers.
  if (isMaterializing(tensor)) {
    Type tp = denseTp.getElementType();
    Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    rewriter.create<linalg::FillOp>(loc, zero, alloc);
    return alloc;
  }
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<arith::ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (shape[p] == MemRefType::kDynamicSize)
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, genIntType(rewriter, 1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<arith::ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
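/// A gather (vector.gather) is emitted when the innermost subscript is itself
/// a vector of indices (the a[ind[lo:hi]] form); otherwise a masked
/// unit-stride load (vector.maskedload) suffices. Both share the current
/// iteration mask and use a zero vector as the pass-through value.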
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass =
      rewriter.create<arith::ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates an affine expression.
//
// TODO: generalize for sparse tensor subscripts
//
static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
                       AffineExpr a, Location loc) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    return codegen.loops[idx]; // universal dense index
  }
  case AffineExprKind::Add: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::AddIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::MulIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Constant: {
    int64_t c = a.cast<AffineConstantExpr>().getValue();
    return rewriter.create<arith::ConstantIndexOp>(loc, c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
                          linalg::GenericOp op, OpOperand *t,
                          SmallVector<Value, 4> &args) {
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    // Note that currently, all sparse subscripts are simple.
    // TODO: accept affine too?
    AffineExpr a = map.getResult(perm(enc, rank - 1));
    assert(a.getKind() == AffineExprKind::DimId);
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
    }
  }
  return codegen.buffers[tensor];
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           Value rhs) {
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(op.getLoc(), codegen.curVecMask, rhs,
                                      codegen.redVal);
    updateReduc(merger, codegen, rhs);
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getOutputOperand(0);
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(op.getLoc(), rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively defines an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
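    // In summary (a sketch of the zero extension rules applied below):
    //   i8/i16 indices       : zero extend to a 32-bit index vector
    //   i32 indices          : zero extend to 64-bit, unless the
    //                          enableSIMDIndex32 option is set
    //   i64 or index indices : used as loaded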
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load =
          rewriter.create<arith::ExtUIOp>(loc, load, genIntType(rewriter, 64));
    load =
        rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv =
        rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<arith::AddIOp>(loc, mul, i);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}

/// Determines if affine expression is invariant.
static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
                              unsigned ldx, bool &atLevel) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (idx == ldx)
      atLevel = true;
    return codegen.loops[idx] != nullptr; // no longer in play?
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
  }
  default:
    return true;
  }
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
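/// For example (a sketch): in x(i,j) = a(i) * b(i,j), the load a[i] does not
/// depend on the inner loop j, so at the start of the j-loop sequence its
/// value is loaded once and reused, and at the end of the sequence it is
/// reset. A load of the output tensor itself instead starts or ends a
/// scalarized reduction.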
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool atStart,
                          Kind last = Kind::kTensor) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (!isInvariantAffine(codegen, a, ldx, atLevel))
        return; // still in play
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    if (!atLevel)
      return;
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      // Start or end a scalarized reduction
      if (atStart) {
        Value load = genTensorLoad(merger, codegen, rewriter, op, exp);
        codegen.redKind = getReduction(last);
        codegen.redExp = exp;
        updateReduc(merger, codegen, load);
      } else {
        Value redVal = codegen.redVal;
        updateReduc(merger, codegen, Value());
        codegen.redExp = -1u;
        codegen.redKind = kNoReduc;
        genTensorStore(merger, codegen, rewriter, op, redVal);
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      merger.exp(exp).val =
          atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    Kind last = merger.exp(exp).kind;
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit stride for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// This prevents effective vectorization.
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        AffineExpr a = map.getResult(d);
        // Report non-unit stride if innermost index appears at an outer
        // dimension (true non-unit stride) or if the innermost index appears
        // in a compound subscript in the innermost dimension. Even if the
        // latter is unit stride, it does not play well with scatter/gather.
        // TODO: accept unit stride affine innermost like a[i,j+k+1]?
        if (a.isFunctionOfDim(idx) &&
            ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId)))
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
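/// Depending on the analysis, this emits (a sketch, not verbatim IR) either
///   scf.parallel (%i) = (%lo) to (%hi) step (%step) { ... }
/// or a sequential/vector loop
///   scf.for %i = %lo to %hi step %step iter_args(...) { ... }
/// where a scalarized reduction, if present, travels as an iter_args value,
/// and for vector loops the step equals the configured vector length.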
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = isReductionIterator(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step =
      rewriter.create<arith::ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential or vector loop.
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    // In a vector loop, bring reduction into SIMD form, if not already.
    if (isVector && !codegen.redVal.getType().isa<VectorType>()) {
      VectorType vtp = vectorType(codegen, codegen.redVal.getType());
      Value vred = genVectorReducInit(codegen, rewriter, loc, vtp);
      updateReduc(merger, codegen, vred);
    }
    operands.push_back(codegen.redVal);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (codegen.redVal)
    updateReduc(merger, codegen, forOp.getRegionIterArgs().front());
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emit a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
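  // Roughly (a sketch of the emitted co-iteration structure):
  //   scf.while (%p_t = %init_t, ..., [%red], [%univ]) {
  //     %cond = conjunction of (%p_t < %high_t) over all sparse tensors t
  //     scf.condition(%cond) ...
  //   } do {
  //     ... loop body emitted by the caller ...
  //   }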
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (codegen.redVal) {
    types.push_back(codegen.redVal.getType());
    operands.push_back(codegen.redVal);
  }
  if (needsUniv) {
    types.push_back(indexType);
    operands.push_back(codegen.loops[idx]);
  }
  assert(types.size() == operands.size());
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                 op1, op2);
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (codegen.redVal)
    updateReduc(merger, codegen, after->getArgument(o++));
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
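  // For each sparse tensor involved here, load its index value at the current
  // position, i.e. idx = indices[pidx], and (unless the universal index is
  // maintained) track the minimum of these loads as the current loop index.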
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction,
                              scf::WhileOp whileOp) {
  Location loc = op.getLoc();
  // Finalize each else branch of all if statements.
  if (codegen.redVal) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               rewriter.getInsertionBlock()->getParentOp())) {
      rewriter.create<scf::YieldOp>(loc, codegen.redVal);
      updateReduc(merger, codegen, ifOp.getResult(0));
      rewriter.setInsertionPointAfter(ifOp);
    }
  }
  rewriter.setInsertionPointToEnd(&whileOp.after().front());
  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
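  // In effect, each sparse position advances conditionally (a sketch):
  //   pidx_t = select(idx_t == min_idx, pidx_t + 1, pidx_t)
  // and the universal index, if maintained, is simply incremented by one.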
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 op1, op2);
      Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
    }
  }
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, whileOp->getResult(o++));
  }
  if (needsUniv) {
    operands.push_back(
        rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = whileOp->getResult(o++);
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(whileOp);
}

/// Generates the induction structure for a for-loop.
static void genForInduction(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            Operation *loop) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, loop->getResult(o++));
  }
  assert(o == operands.size());
  if (o > 0)
    rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(loop);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  SmallVector<Type, 4> types;
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                op1, op2);
      } else {
        clause = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
    }
  }
  if (codegen.redVal)
    types.push_back(codegen.redVal.getType());
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Generates end of true branch of if-statement within a while-loop.
static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                  linalg::GenericOp op, scf::IfOp ifOp, Value ifInput) {
  if (codegen.redVal) {
    rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal);
    updateReduc(merger, codegen, ifInput);
  }
  rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         std::vector<unsigned> &topSort, unsigned exp,
                         unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  assert(!codegen.loops[idx]);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  bool needsUniv =
      genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  if (needsUniv) {
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}

/// Starts a single loop in current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in current sequence. Returns the new value for needsUniv.
static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp);
    return needsUniv;
  }
  // End a for-loop.
  genForInduction(merger, codegen, rewriter, op, loop);
  return false;
}

/// Ends a loop sequence at given level.
static void endLoopSeq(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned exp, unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  codegen.loops[idx] = Value();
  // Bring a pending reduction back from SIMD form when sequence ends.
  if (codegen.redVal)
    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
      updateReduc(merger, codegen,
                  genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
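/// For example (a sketch), for x(i) = a(i) + b(i) with both a and b sparse,
/// the optimized lattice for loop i yields a while-loop that co-iterates as
/// long as both a and b have remaining entries (with if-statements selecting
/// the a-and-b, a-only, or b-only case per iteration), followed by simpler
/// loops over the remaining tail of a only and of b only.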
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    Value ifInput = codegen.redVal;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, ifInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, idx, ldx);
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should never need this eventually).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
    if (!optExp.hasValue())
      return failure();
    unsigned exp = optExp.getValue();

    // Rejects an inadmissible tensor expression.
    if (!isAdmissableTensorExp(merger, op, exp))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}