//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(OpOperand *op) {
  if (auto enc = getSparseTensorEncoding(op->get().getType())) {
    ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes =
        enc.getDimLevelType();
    for (auto dimType : dimTypes)
      if (dimType == SparseTensorEncodingAttr::DimLevelType::Compressed)
        return true; // at least one compressed
  }
  return false;
}

// Helper method to find zero or empty initialization.
static bool isEmptyInit(OpOperand *op) {
  Value val = op->get();
  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) ||
         val.getDefiningOp<InitTensorOp>() ||
         val.getDefiningOp<AllocTensorOp>();
}

// Helper to detect sampling operation.
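// For example, a linalg.generic body of the form (a sketch; the element
// type is arbitrary, and arith.muli is accepted as well)
//   ^bb0(%s1: f64, %s2: f64, %out: f64):
//     %0 = arith.mulf %s1, %s2 : f64
//     linalg.yield %0 : f64
// is recognized as a sampling operation.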
static bool isSampling(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
      // Both scalar input arguments used exactly once.
      Value s1 = op.getBlock()->getArgument(0);
      Value s2 = op.getBlock()->getArgument(1);
      return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
             (def->getOperand(1) == s1 && def->getOperand(0) == s2);
    }
  }
  return false;
}

// Helper to detect chain of multiplications that do not involve x.
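// For example, for block arguments a, b, c all different from x, both a
// alone and (a * b) * c form mul chains; any value that involves x itself
// is rejected.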
static bool isMulChain(Value val, Value x) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    return arg != x;
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
      return isMulChain(def->getOperand(0), x) &&
             isMulChain(def->getOperand(1), x);
  }
  return false;
}

// Helper to detect x = x + <multiplications>.
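// For example, x = x + a * b and x = x + (a * b) * c match, but
// x = x + x * a does not, since the mul chain may not involve x.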
static bool isSumOfMul(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments().back();
      return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
             (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
    }
  }
  return false;
}

//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// Rewriting rule that converts two kernels:
///
///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
///      X(i,j) = S(i,j) * T(i,j)
///
/// into a single kernel, using distributive law:
///
///      X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
///
/// This kind of fusion (merging two ops into one but using arithmetic
/// equalities that may not hold for floating-point computations) would
/// be undesirable in the dense case, since we distribute the multiplication
/// into the reduction loop. However, for a sparse sampling tensor S, such
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
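///
/// For example (a sketch; this pattern is not tied to one particular
/// kernel), for a sampled dense-dense matrix multiplication
///
///      T(i,j) = SUM(k, A(i,k) * B(k,j))
///      X(i,j) = S(i,j) * T(i,j)
///
/// the fused kernel X(i,j) = SUM(k, S(i,j) * A(i,k) * B(k,j)) only needs
/// to evaluate the reduction over k at the nonzero entries of S.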
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Check consumer.
    if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
        op.getNumResults() != 1 ||
        op.getNumParallelLoops() != op.getNumLoops() ||
        !op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
        !op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
        !op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
      return failure();
    // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
    // operand can be sparse or dense, since the point of this rewriting rule
    // is detecting a situation in which *more* sparsity is introduced into
    // a computation, be it already sparse or still dense.
    unsigned other = 0;
    if (isSparseTensor(op.getInputOperand(0)))
      other = 1;
    else if (!isSparseTensor(op.getInputOperand(1)))
      return failure();
    // Check producer.
    auto prod = dyn_cast_or_null<GenericOp>(
        op.getInputOperand(other)->get().getDefiningOp());
    if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
        !prod.getResult(0).hasOneUse())
      return failure();
    // Sampling consumer and sum of multiplication chain producer.
    if (!isEmptyInit(op.getOutputOperand(0)) ||
        !isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
        !isSumOfMul(prod))
      return failure();
    // Modify operand structure of producer and consumer.
    Location loc = prod.getLoc();
    SmallVector<Value> inputOps = prod.getInputOperands();
    SmallVector<Value> outputOps = op.getOutputOperands();
    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
    // Fuse producer and consumer into a new generic op.
    auto fusedOp = rewriter.create<GenericOp>(
        loc, op.getResult(0).getType(), inputOps, outputOps,
        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(),
        /*doc=*/nullptr, /*library_call=*/nullptr);
    Block &prodBlock = prod.region().front();
    Block &consBlock = op.region().front();
    BlockAndValueMapping mapper;
    Block *fusedBlock = new Block();
    fusedOp.region().push_back(fusedBlock);
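    // Set up the fused block arguments: all producer arguments except the
    // trailing accumulator first, then the consumer's argument for the
    // sampling operand, and finally the producer's accumulator argument.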
    unsigned num = prodBlock.getNumArguments();
    for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
    // Clone bodies of the producer and consumer in new evaluation order.
    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
    rewriter.setInsertionPointToStart(fusedBlock);
    Value last;
    for (auto &op : prodBlock.without_terminator())
      if (&op != acc) {
        last = op.getResult(0);
        rewriter.clone(op, mapper);
      }
    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
    mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
    last = rewriter.clone(*acc, mapper)->getResult(0);
    rewriter.create<linalg::YieldOp>(loc, last);
    // Replace consumer with fused operation. Old producer
    // and consumer ops will be removed by DCE.
    rewriter.replaceOp(op, fusedOp->getResults());
    return success();
  }

private:
  // Helper to add argument and record the mapping.
  static void addArg(BlockAndValueMapping &mapper, Block *b, BlockArgument a) {
    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
  }
};

/// Sparse rewriting rule for reshape operator.
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.getResult().getType());
    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
    // Since a pure dense expansion is very cheap (change of view), for
    // a sparse2dense or dense2sparse reshape, we can simply unfuse a sparse
    // conversion from the reshape operation itself.
    // All other cases are handled elsewhere.
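    // For example (a sketch of the dense2sparse case; types and the
    // reassociation attribute omitted):
    //   %1 = tensor.expand_shape %0 : dense -> sparse
    // is rewritten below into
    //   %1 = tensor.expand_shape %0 : dense -> dense
    //   %2 = sparse_tensor.convert %1 : dense -> sparse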
    if (encDst && encSrc) {
      return failure();
    } else if (encSrc) {
      RankedTensorType rtp =
          op.getSrc().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
      op->setOperand(0, convert);
      return success();
    } else if (encDst) {
      RankedTensorType rtp =
          op.getResult().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
                                                op.getReassociation());
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
  // TODO(springerm): enable FuseSparseMultiplyOverAdd
  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}
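
// A minimal usage sketch (hypothetical caller; assumes the standard greedy
// rewrite driver from mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(op->getContext());
//   populateSparseTensorRewriting(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));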