//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

// Helper to detect a sparse tensor type operand.
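// For example (a sketch; the exact encoding syntax may vary by version), an
// operand with the following CSR-style type is detected as sparse, since its
// second dimension level is compressed:
//
//   #CSR = #sparse_tensor.encoding<{
//     dimLevelType = [ "dense", "compressed" ]
//   }>
//   ... tensor<?x?xf64, #CSR> ...
//
// whereas an all-dense encoding, or no encoding at all, is not.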
static bool isSparseTensor(OpOperand *op) {
  if (auto enc = getSparseTensorEncoding(op->get().getType())) {
    ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes =
        enc.getDimLevelType();
    for (auto dimType : dimTypes)
      if (dimType == SparseTensorEncodingAttr::DimLevelType::Compressed)
        return true; // at least one compressed
  }
  return false;
}

// Helper method to find zero or empty initialization.
static bool isEmptyInit(OpOperand *op) {
  Value val = op->get();
  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) ||
         val.getDefiningOp<InitTensorOp>() ||
         val.getDefiningOp<AllocTensorOp>();
}

// Helper to detect sampling operation.
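// For example (a sketch), a generic op body of the following shape is
// recognized, where %s1 and %s2 are the two scalar input arguments:
//
//   ^bb0(%s1: f64, %s2: f64, %out: f64):
//     %0 = arith.mulf %s1, %s2 : f64
//     linalg.yield %0 : f64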
static bool isSampling(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
      // Both scalar input arguments used exactly once.
      Value s1 = op.getBlock()->getArgument(0);
      Value s2 = op.getBlock()->getArgument(1);
      return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
             (def->getOperand(1) == s1 && def->getOperand(0) == s2);
    }
  }
  return false;
}

// Helper to detect chain of multiplications that do not involve x.
static bool isMulChain(Value val, Value x) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    return arg != x;
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
      return isMulChain(def->getOperand(0), x) &&
             isMulChain(def->getOperand(1), x);
  }
  return false;
}

// Helper to detect x = x + <multiplications>.
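// For example (a sketch), with %x the output block argument:
//
//   ^bb0(%a: f64, %b: f64, %x: f64):
//     %0 = arith.mulf %a, %b : f64
//     %1 = arith.addf %x, %0 : f64
//     linalg.yield %1 : f64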
static bool isSumOfMul(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments().back();
      return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
             (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
    }
  }
  return false;
}

//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// Rewriting rule that converts two kernels:
///
///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
///      X(i,j) = S(i,j) * T(i,j)
///
/// into a single kernel, using the distributive law:
///
///      X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
///
/// This kind of fusion (merging two ops into one but using arithmetic
/// equalities that may not hold for floating-point computations) would
/// be undesirable in the dense case, since we distribute the multiplication
/// into the reduction loop. However, for a sparse sampling tensor S, such
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
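///
/// As a sketch in linalg terms (operand types, indexing maps, and iterator
/// attributes elided; value names hypothetical), the rewrite fuses
///
///   %t = linalg.generic ... ins(%a, %b) outs(%init1) { sum-of-mul body }
///   %x = linalg.generic ... ins(%s, %t) outs(%init2) { sampling-mul body }
///
/// into a single linalg.generic over %a, %b, and %s that multiplies by the
/// sampling value before accumulating, so that zeros in a sparse %s can
/// prune entire reduction iterations.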
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Check consumer.
    if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
        op.getNumResults() != 1 ||
        op.getNumParallelLoops() != op.getNumLoops() ||
        !op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
        !op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
        !op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
      return failure();
    // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
    // operand can be sparse or dense, since the point of this rewriting rule
    // is detecting a situation in which *more* sparsity is introduced into
    // a computation, be it already sparse or still dense.
    unsigned other = 0;
    if (isSparseTensor(op.getInputOperand(0)))
      other = 1;
    else if (!isSparseTensor(op.getInputOperand(1)))
      return failure();
    // Check producer.
    auto prod = dyn_cast_or_null<GenericOp>(
        op.getInputOperand(other)->get().getDefiningOp());
    if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
        !prod.getResult(0).hasOneUse())
      return failure();
    // Sampling consumer and sum of multiplication chain producer.
    if (!isEmptyInit(op.getOutputOperand(0)) ||
        !isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
        !isSumOfMul(prod))
      return failure();
    // Modify operand structure of producer and consumer.
    Location loc = prod.getLoc();
    SmallVector<Value> inputOps = prod.getInputOperands();
    SmallVector<Value> outputOps = op.getOutputOperands();
    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
    // Fuse producer and consumer into a new generic op.
    auto fusedOp = rewriter.create<GenericOp>(
        loc, op.getResult(0).getType(), inputOps, outputOps,
        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(),
        /*doc=*/nullptr, /*library_call=*/nullptr);
    Block &prodBlock = prod.region().front();
    Block &consBlock = op.region().front();
    BlockAndValueMapping mapper;
    Block *fusedBlock = new Block();
    fusedOp.region().push_back(fusedBlock);
    unsigned num = prodBlock.getNumArguments();
    // Add the producer block arguments except the last (output) one, then
    // the consumer's sampling scalar, then the output argument.
    for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
    // Clone bodies of the producer and consumer in new evaluation order:
    // first the multiplication chain, then the sampling multiplication,
    // and only then the accumulation.
    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
    rewriter.setInsertionPointToStart(fusedBlock);
    Value last;
    for (auto &op : prodBlock.without_terminator())
      if (&op != acc) {
        last = op.getResult(0);
        rewriter.clone(op, mapper);
      }
    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
    mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
    last = rewriter.clone(*acc, mapper)->getResult(0);
    rewriter.create<linalg::YieldOp>(loc, last);
    // Replace consumer with fused operation. Old producer
    // and consumer ops will be removed by DCE.
    rewriter.replaceOp(op, fusedOp->getResults());
    return success();
  }

private:
  // Helper to add argument and record the mapping.
  static void addArg(BlockAndValueMapping &mapper, Block *b, BlockArgument a) {
    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
  }
};

/// Sparse rewriting rule for reshape operator.
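///
/// For a dense-to-sparse expansion, for example, the rewrite unfuses the
/// sparse conversion as follows (a sketch; shapes and reassociation elided):
///
///   %1 = tensor.expand_shape %0 : tensor<...> into tensor<..., #Sparse>
///
/// becomes
///
///   %d = tensor.expand_shape %0 : tensor<...> into tensor<...>
///   %1 = sparse_tensor.convert %d : tensor<...> to tensor<..., #Sparse>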
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.getResult().getType());
    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
    // Since a pure dense reshape is very cheap (a change of view), for
    // a sparse-to-dense or dense-to-sparse reshape, we can simply unfuse
    // a sparse conversion from the reshape operation itself.
    // All other cases are handled elsewhere.
    if (encDst && encSrc) {
      return failure();
    } else if (encSrc) {
      // Sparse source: convert to a dense tensor first, then reshape that.
      RankedTensorType rtp =
          op.getSrc().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
      op->setOperand(0, convert);
      return success();
    } else if (encDst) {
      // Sparse destination: reshape densely, then convert to sparse.
      RankedTensorType rtp =
          op.getResult().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
                                                op.getReassociation());
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
  // TODO(springerm): enable FuseSparseMultiplyOverAdd
  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}
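
// A minimal usage sketch from within a pass (the enclosing pass is assumed,
// not defined in this file):
//
//   RewritePatternSet patterns(&getContext());
//   populateSparseTensorRewriting(patterns);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));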