//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

// Helper to detect a sparse tensor type operand.
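//
// For example (illustrative), an operand annotated with a CSR-style encoding
//
//   #CSR = #sparse_tensor.encoding<{
//     dimLevelType = [ "dense", "compressed" ]
//   }>
//
// qualifies, since its second dimension level is compressed.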
static bool isSparseTensor(OpOperand *op) {
  if (auto enc = getSparseTensorEncoding(op->get().getType())) {
    ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes =
        enc.getDimLevelType();
    for (auto dimType : dimTypes)
      if (dimType == SparseTensorEncodingAttr::DimLevelType::Compressed)
        return true; // at least one compressed
  }
  return false;
}

// Helper to detect zero or empty initialization.
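//
// For example (illustrative), each of the following values qualifies:
//
//   %a = arith.constant dense<0.0> : tensor<4xf64>
//   %b = linalg.init_tensor [4] : tensor<4xf64>
//   %c = bufferization.alloc_tensor() : tensor<4xf64, #SV>
//
// where #SV denotes some sparse encoding (hypothetical name).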
static bool isEmptyInit(OpOperand *op) {
  Value val = op->get();
  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) ||
         val.getDefiningOp<InitTensorOp>() ||
         val.getDefiningOp<AllocTensorOp>();
}

// Helper to detect a sampling operation.
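//
// For example (illustrative), a linalg.generic body of the form
//
//   ^bb0(%s: f64, %t: f64, %out: f64):
//     %0 = arith.mulf %s, %t : f64
//     linalg.yield %0 : f64
//
// is a sampling operation: it multiplies the two scalar input arguments.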
static bool isSampling(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
      // Both scalar input arguments used exactly once.
      Value s1 = op.getBlock()->getArgument(0);
      Value s2 = op.getBlock()->getArgument(1);
      return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
             (def->getOperand(1) == s1 && def->getOperand(0) == s2);
    }
  }
  return false;
}

// Helper to detect a chain of multiplications that does not involve x.
static bool isMulChain(Value val, Value x) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    return arg != x;
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
      return isMulChain(def->getOperand(0), x) &&
             isMulChain(def->getOperand(1), x);
  }
  return false;
}

// Helper to detect x = x + <multiplications>.
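//
// For example (illustrative), with block arguments %a, %b, %x the body
//
//   %0 = arith.mulf %a, %b : f64
//   %1 = arith.addf %x, %0 : f64
//   linalg.yield %1 : f64
//
// is recognized, since %0 is a multiplication chain that does not involve %x.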
static bool isSumOfMul(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments().back();
      return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
             (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
    }
  }
  return false;
}

//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// Rewriting rule that converts two kernels:
///
///   T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
///   X(i,j) = S(i,j) * T(i,j)
///
/// into a single kernel, using the distributive law:
///
///   X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
///
/// This kind of fusion (merging two ops into one, but using arithmetic
/// equalities that may not hold for floating-point computations) would
/// be undesirable in the dense case, since we distribute the multiplication
/// into the reduction loop. However, for a sparse sampling tensor S, such
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
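///
/// As an illustrative example, an SDDMM-like pair of kernels
///
///   %t = linalg.generic ... ins(%a, %b : ...) outs(%init : ...)   // T(i,j)
///   %x = linalg.generic ... ins(%s, %t : ...) outs(%init2 : ...)  // X(i,j)
///
/// is fused into a single linalg.generic that multiplies by the sampling
/// tensor %s inside the reduction loop (names are hypothetical).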
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Check consumer.
    if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
        op.getNumResults() != 1 ||
        op.getNumParallelLoops() != op.getNumLoops() ||
        !op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
        !op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
        !op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
      return failure();
    // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
    // operand can be sparse or dense, since the point of this rewriting rule
    // is detecting a situation in which *more* sparsity is introduced into
    // a computation, be it already sparse or still dense.
    unsigned other = 0;
    if (isSparseTensor(op.getInputOperand(0)))
      other = 1;
    else if (!isSparseTensor(op.getInputOperand(1)))
      return failure();
    // Check producer.
    auto prod = dyn_cast_or_null<GenericOp>(
        op.getInputOperand(other)->get().getDefiningOp());
    if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
        !prod.getResult(0).hasOneUse())
      return failure();
    // Sampling consumer and sum of multiplication chain producer.
    if (!isEmptyInit(op.getOutputOperand(0)) ||
        !isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
        !isSumOfMul(prod))
      return failure();
    // Modify operand structure of producer and consumer.
    Location loc = prod.getLoc();
    SmallVector<Value> inputOps = prod.getInputOperands();
    SmallVector<Value> outputOps = op.getOutputOperands();
    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
    // Fuse producer and consumer into a new generic op.
    auto fusedOp = rewriter.create<GenericOp>(
        loc, op.getResult(0).getType(), inputOps, outputOps,
        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(),
        /*doc=*/nullptr, /*library_call=*/nullptr);
    Block &prodBlock = prod.region().front();
    Block &consBlock = op.region().front();
    BlockAndValueMapping mapper;
    Block *fusedBlock = new Block();
    fusedOp.region().push_back(fusedBlock);
    unsigned num = prodBlock.getNumArguments();
    for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
    // Clone bodies of the producer and consumer in new evaluation order.
    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
    rewriter.setInsertionPointToStart(fusedBlock);
    Value last;
    for (auto &op : prodBlock.without_terminator())
      if (&op != acc) {
        last = op.getResult(0);
        rewriter.clone(op, mapper);
      }
    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
    mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
    last = rewriter.clone(*acc, mapper)->getResult(0);
    rewriter.create<linalg::YieldOp>(loc, last);
    // Replace consumer with fused operation. Old producer
    // and consumer ops will be removed by DCE.
    rewriter.replaceOp(op, fusedOp->getResults());
    return success();
  }

private:
  // Helper to add argument and record the mapping.
  static void addArg(BlockAndValueMapping &mapper, Block *b, BlockArgument a) {
    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
  }
};

/// Sparse rewriting rule for reshape operator.
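///
/// As an illustrative example, a dense2sparse expansion such as
///
///   %1 = tensor.expand_shape %0 [[0, 1]]
///       : tensor<12xf64> into tensor<3x4xf64, #SV>
///
/// is rewritten into a dense reshape followed by a conversion:
///
///   %d = tensor.expand_shape %0 [[0, 1]]
///       : tensor<12xf64> into tensor<3x4xf64>
///   %1 = sparse_tensor.convert %d
///       : tensor<3x4xf64> to tensor<3x4xf64, #SV>
///
/// where #SV denotes some sparse encoding (hypothetical name).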
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.getResult().getType());
    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
    // Since a pure dense expansion is very cheap (a mere change of view),
    // for a sparse2dense or dense2sparse reshape we can simply unfuse the
    // sparse conversion from the reshape operation itself.
    // All other cases are handled elsewhere.
    if (encDst && encSrc) {
      return failure();
    } else if (encSrc) {
      RankedTensorType rtp =
          op.getSrc().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
      op->setOperand(0, convert);
      return success();
    } else if (encDst) {
      RankedTensorType rtp =
          op.getResult().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
                                                op.getReassociation());
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

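// Typical usage, as an illustrative sketch from inside some pass (the
// surrounding pass is hypothetical; any greedy pattern driver works):
//
//   RewritePatternSet patterns(&getContext());
//   populateSparseTensorRewriting(patterns);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//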
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
  // TODO(springerm): enable FuseSparseMultiplyOverAdd
  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}