//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(OpOperand *op) {
  if (auto enc = getSparseTensorEncoding(op->get().getType())) {
    ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes =
        enc.getDimLevelType();
    for (auto dimType : dimTypes)
      if (dimType == SparseTensorEncodingAttr::DimLevelType::Compressed)
        return true; // at least one compressed
  }
  return false;
}
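// For example (a sketch; the exact attribute syntax may differ across MLIR
// versions), an operand of type
//   tensor<32x64xf64, #sparse_tensor.encoding<{
//     dimLevelType = [ "dense", "compressed" ] }>>
// is detected, whereas an all-dense encoding is not.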

// Helper method to find zero or empty initialization.
static bool isEmptyInit(OpOperand *op) {
  Value val = op->get();
  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) ||
         val.getDefiningOp<InitTensorOp>() ||
         val.getDefiningOp<AllocTensorOp>();
}
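// This matches, for example, an integer or floating-point zero constant
// produced by arith.constant, as well as the result of linalg.init_tensor
// or bufferization.alloc_tensor feeding the output operand.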

// Helper to detect sampling operation.
static bool isSampling(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
      // Both scalar input arguments used exactly once.
      Value s1 = op.getBlock()->getArgument(0);
      Value s2 = op.getBlock()->getArgument(1);
      return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
             (def->getOperand(1) == s1 && def->getOperand(0) == s2);
    }
  }
  return false;
}
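// A matching region body looks, schematically, as follows:
//   ^bb0(%s1: f64, %s2: f64, %out: f64):
//     %0 = arith.mulf %s1, %s2 : f64
//     linalg.yield %0 : f64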

// Helper to detect chain of multiplications that do not involve x.
static bool isMulChain(Value val, Value x) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    return arg != x;
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
      return isMulChain(def->getOperand(0), x) &&
             isMulChain(def->getOperand(1), x);
  }
  return false;
}
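// For example, in x = x + a * b * c, the subexpression a * b * c forms a
// multiplication chain with respect to x, since none of its leaves is x.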

// Helper to detect x = x + <multiplications>.
static bool isSumOfMul(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments().back();
      return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
             (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
    }
  }
  return false;
}
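// A matching region body looks, schematically, as follows, where %x is the
// output block argument that accumulates the reduction:
//   ^bb0(%a: f64, %b: f64, %x: f64):
//     %0 = arith.mulf %a, %b : f64
//     %1 = arith.addf %x, %0 : f64
//     linalg.yield %1 : f64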

//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// Rewriting rule that converts two kernels:
///
///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
///      X(i,j) = S(i,j) * T(i,j)
///
/// into a single kernel, using the distributive law:
///
///      X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
///
/// This kind of fusion (merging two ops into one, but using arithmetic
/// equalities that may not hold for floating-point computations) would
/// be undesirable in the dense case, since we distribute the multiplication
/// into the reduction loop. However, for a sparse sampling tensor S, such
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
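///
/// Sampled dense-dense matrix multiplication (SDDMM) is the canonical
/// instance. Schematically (a sketch, with types and indexing maps elided),
/// the producer/consumer pair
///
///   %t = linalg.generic ins(%a, %b) outs(%init0) { multiply-accumulate }
///   %x = linalg.generic ins(%s, %t) outs(%init1) { multiply }
///
/// becomes a single linalg.generic over ins(%a, %b, %s) whose body first
/// multiplies the sampling value into the chain and then accumulates.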
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Check consumer.
    if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
        op.getNumResults() != 1 ||
        op.getNumParallelLoops() != op.getNumLoops() ||
        !op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
        !op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
        !op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
      return failure();
    // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
    // operand can be sparse or dense, since the point of this rewriting rule
    // is detecting a situation in which *more* sparsity is introduced into
    // a computation, be it already sparse or still dense.
    unsigned other = 0;
    if (isSparseTensor(op.getInputOperand(0)))
      other = 1;
    else if (!isSparseTensor(op.getInputOperand(1)))
      return failure();
    // Check producer.
    auto prod = dyn_cast_or_null<GenericOp>(
        op.getInputOperand(other)->get().getDefiningOp());
    if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
        !prod.getResult(0).hasOneUse())
      return failure();
    // Sampling consumer and sum of multiplication chain producer.
    if (!isEmptyInit(op.getOutputOperand(0)) ||
        !isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
        !isSumOfMul(prod))
      return failure();
    // Modify operand structure of producer and consumer.
    Location loc = prod.getLoc();
    SmallVector<Value> inputOps = prod.getInputOperands();
    SmallVector<Value> outputOps = op.getOutputOperands();
    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
    // Fuse producer and consumer into a new generic op.
    auto fusedOp = rewriter.create<GenericOp>(
        loc, op.getResult(0).getType(), inputOps, outputOps,
        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(),
        /*doc=*/nullptr, /*library_call=*/nullptr);
    Block &prodBlock = prod.region().front();
    Block &consBlock = op.region().front();
    BlockAndValueMapping mapper;
    Block *fusedBlock = new Block();
    fusedOp.region().push_back(fusedBlock);
    unsigned num = prodBlock.getNumArguments();
    for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
    // Clone bodies of the producer and consumer in new evaluation order.
    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
    rewriter.setInsertionPointToStart(fusedBlock);
    Value last;
    for (auto &op : prodBlock.without_terminator())
      if (&op != acc) {
        last = op.getResult(0);
        rewriter.clone(op, mapper);
      }
    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
    mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
    last = rewriter.clone(*acc, mapper)->getResult(0);
    rewriter.create<linalg::YieldOp>(loc, last);
    // Replace consumer with fused operation. Old producer
    // and consumer ops will be removed by DCE.
    rewriter.replaceOp(op, fusedOp->getResults());
    return success();
  }

private:
  // Helper to add argument and record the mapping.
  static void addArg(BlockAndValueMapping &mapper, Block *b, BlockArgument a) {
    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
  }
};

/// Sparse rewriting rule for reshape operator.
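///
/// For a sparse source (a sketch, with the reassociation elided), the rule
/// rewrites
///   %r = tensor.collapse_shape %sparse ...
/// into a sparse-to-dense conversion feeding the original reshape:
///   %d = sparse_tensor.convert %sparse ...
///   %r = tensor.collapse_shape %d ...
/// The sparse-destination case symmetrically reshapes in the dense domain
/// first and then converts the result back to the sparse type.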
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.getResult().getType());
    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
    // Since a pure dense expansion is very cheap (change of view), for
    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
    // conversion from the reshape operation itself.
    // All other cases are handled elsewhere.
    if (encDst && encSrc) {
      return failure();
    } else if (encSrc) {
      RankedTensorType rtp =
          op.getSrc().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
      op->setOperand(0, convert);
      return success();
    } else if (encDst) {
      RankedTensorType rtp =
          op.getResult().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
                                                op.getReassociation());
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

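// A pass that wants these rules would typically register them and apply them
// greedily, for example (a sketch; requires
// mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(op->getContext());
//   populateSparseTensorRewriting(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));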
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
  // TODO(springerm): enable FuseSparseMultiplyOverAdd
  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}