1 //===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements rewriting rules that are specific to sparse tensors. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 17 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 18 #include "mlir/Dialect/Tensor/IR/Tensor.h" 19 #include "mlir/IR/AffineMap.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/Support/LLVM.h" 22 23 using namespace mlir; 24 using namespace mlir::bufferization; 25 using namespace mlir::linalg; 26 using namespace mlir::sparse_tensor; 27 28 //===---------------------------------------------------------------------===// 29 // Helper methods for the actual rewriting rules. 30 //===---------------------------------------------------------------------===// 31 32 // Helper to detect a sparse tensor type operand. 33 static bool isSparseTensor(OpOperand *op) { 34 if (auto enc = getSparseTensorEncoding(op->get().getType())) { 35 ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes = 36 enc.getDimLevelType(); 37 for (auto dimType : dimTypes) 38 if (dimType == SparseTensorEncodingAttr::DimLevelType::Compressed) 39 return true; // at least one compressed 40 } 41 return false; 42 } 43 44 // Helper method to find zero or empty initialization. 
45 static bool isEmptyInit(OpOperand *op) { 46 Value val = op->get(); 47 return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) || 48 val.getDefiningOp<InitTensorOp>() || 49 val.getDefiningOp<AllocTensorOp>(); 50 } 51 52 // Helper to detect sampling operation. 53 static bool isSampling(GenericOp op) { 54 auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator()); 55 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) { 56 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) { 57 // Both scalar input arguments used exactly once. 58 Value s1 = op.getBlock()->getArgument(0); 59 Value s2 = op.getBlock()->getArgument(1); 60 return (def->getOperand(0) == s1 && def->getOperand(1) == s2) || 61 (def->getOperand(1) == s1 && def->getOperand(0) == s2); 62 } 63 } 64 return false; 65 } 66 67 // Helper to detect chain of multiplications that do not involve x. 68 static bool isMulChain(Value val, Value x) { 69 if (auto arg = val.dyn_cast<BlockArgument>()) 70 return arg != x; 71 if (auto *def = val.getDefiningOp()) { 72 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) 73 return isMulChain(def->getOperand(0), x) && 74 isMulChain(def->getOperand(1), x); 75 } 76 return false; 77 } 78 79 // Helper to detect x = x + <multiplications>. 80 static bool isSumOfMul(GenericOp op) { 81 auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator()); 82 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) { 83 if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) { 84 Value x = op.getBlock()->getArguments().back(); 85 return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) || 86 (def->getOperand(1) == x && isMulChain(def->getOperand(0), x)); 87 } 88 } 89 return false; 90 } 91 92 //===---------------------------------------------------------------------===// 93 // The actual sparse tensor rewriting rules. 
//===---------------------------------------------------------------------===//

namespace {

/// Rewriting rule that converts two kernels:
///
///   T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
///   X(i,j) = S(i,j) * T(i,j)
///
/// into a single kernel, using distributive law:
///
///   X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
///
/// This kind of fusion (merging two ops into one but using arithmetic
/// equalities that may not hold for floating-point computations) would
/// be undesirable in the dense case, since we distribute the multiplication
/// into the reduction loop. However, for sparse sampling tensor S, such
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Check consumer: a two-input, one-result generic op with only parallel
    // loops and identity indexing maps on all operands (the sampling kernel
    // X = S * T above).
    if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
        op.getNumResults() != 1 ||
        op.getNumParallelLoops() != op.getNumLoops() ||
        !op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
        !op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
        !op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
      return failure();
    // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
    // operand can be sparse or dense, since the point of this rewriting rule
    // is detecting a situation in which *more* sparsity is introduced into
    // a computation, be it already sparse or still dense.
    unsigned other = 0;
    if (isSparseTensor(op.getInputOperand(0)))
      other = 1;
    else if (!isSparseTensor(op.getInputOperand(1)))
      return failure();
    // Check producer: a single-result generic op whose only use is the
    // consumer's `other` input (so fusing it away is safe).
    auto prod = dyn_cast_or_null<GenericOp>(
        op.getInputOperand(other)->get().getDefiningOp());
    if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
        !prod.getResult(0).hasOneUse())
      return failure();
    // Sampling consumer and sum of multiplication chain producer. Both
    // outputs must start from a zero/empty init so the reduction semantics
    // are preserved by the fused kernel.
    if (!isEmptyInit(op.getOutputOperand(0)) ||
        !isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
        !isSumOfMul(prod))
      return failure();
    // Modify operand structure of producer and consumer: the fused op takes
    // the producer's inputs plus the sampling tensor S, and writes into the
    // consumer's output.
    Location loc = prod.getLoc();
    SmallVector<Value> inputOps = prod.getInputOperands();
    SmallVector<Value> outputOps = op.getOutputOperands();
    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
    // Fuse producer and consumer into a new generic op, keeping the
    // producer's iterator types (including its reduction dimension).
    auto fusedOp = rewriter.create<GenericOp>(
        loc, op.getResult(0).getType(), inputOps, outputOps,
        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(),
        /*doc=*/nullptr, /*library_call=*/nullptr);
    Block &prodBlock = prod.region().front();
    Block &consBlock = op.region().front();
    BlockAndValueMapping mapper;
    // Ownership of the new block is transferred to the fused op's region.
    Block *fusedBlock = new Block();
    fusedOp.region().push_back(fusedBlock);
    // Block arguments: producer's inputs, then the sampling scalar S, then
    // the output accumulator (producer's last argument).
    unsigned num = prodBlock.getNumArguments();
    for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
    // Clone bodies of the producer and consumer in new evaluation order:
    // first the producer's mul chain, then the consumer's sampling mul,
    // finally the producer's accumulating add.
    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
    rewriter.setInsertionPointToStart(fusedBlock);
    Value last;
    for (auto &op : prodBlock.without_terminator())
      if (&op != acc) {
        // `last` tracks the producer's final chain value (pre-clone).
        last = op.getResult(0);
        rewriter.clone(op, mapper);
      }
    // The consumer's T operand now refers to the last cloned chain value;
    // the sampler's result then stands in for the producer's chain result,
    // so the cloned `acc` accumulates S * <chain> instead of <chain>.
    // NOTE(review): if the producer body contains no operation besides
    // `acc` (e.g. the mul chain is a bare block argument), `last` stays
    // null and `fusedBlock->back()` is ill-defined — presumably excluded
    // by the kernels this pattern targets; verify.
    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
    mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
    last = rewriter.clone(*acc, mapper)->getResult(0);
    rewriter.create<linalg::YieldOp>(loc, last);
    // Replace consumer with fused operation. Old producer
    // and consumer ops will be removed by DCE.
    rewriter.replaceOp(op, fusedOp->getResults());
    return success();
  }

private:
  // Helper to add argument and record the mapping from the old block
  // argument to the newly added one.
  static void addArg(BlockAndValueMapping &mapper, Block *b, BlockArgument a) {
    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
  }
};

/// Sparse rewriting rule for reshape operator (expand/collapse shape).
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Encodings (if any) on the reshape's destination and source types.
    auto encDst = getSparseTensorEncoding(op.getResult().getType());
    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
    // Since a pure dense expansion is very cheap (change of view), for
    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
    // conversion from the reshape operation itself.
    // All other cases are handled elsewhere.
    // sparse2sparse reshape: not handled by this rule.
    if (encDst && encSrc) {
      return failure();
    } else if (encSrc) {
      // sparse2dense: convert the sparse source to an equivalent dense
      // tensor and let the reshape operate on the dense value instead.
      RankedTensorType rtp =
          op.getSrc().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
      // Redirect the reshape's operand in place to the dense conversion.
      op->setOperand(0, convert);
      return success();
    } else if (encDst) {
      // dense2sparse: reshape into a dense tensor with the destination's
      // shape first, then convert that result to the sparse target type.
      RankedTensorType rtp =
          op.getResult().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
                                                op.getReassociation());
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    // Pure dense reshape: nothing for this rule to do.
    return failure();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

// Registers the sparse tensor rewriting rules defined in this file with the
// given pattern set.
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
  // TODO(springerm): enable FuseSparseMultiplyOverAdd
  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}