1 //===- TosaFoldConstantTranspose.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Fold TOSA Transpose operation on constant data
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
14 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/Pass/Pass.h"
17
18 using namespace mlir;
19 using namespace mlir::tosa;
20
21 namespace {
22
23 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
24 using OpRewritePattern::OpRewritePattern;
25
matchAndRewrite__anon4c0f76470111::TosaFoldConstantTranspose26 LogicalResult matchAndRewrite(tosa::TransposeOp op,
27 PatternRewriter &rewriter) const override {
28 auto outputType = op.getType().cast<ShapedType>();
29 // TOSA supports quantized types.
30 if (!outputType.getElementType().isIntOrIndexOrFloat())
31 return failure();
32
33 DenseElementsAttr inputValues;
34 if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
35 return failure();
36 // Make sure the input is a constant that has a single user.
37 if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
38 return failure();
39
40 DenseIntElementsAttr permAttr;
41 if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
42 return failure();
43 auto permValues = llvm::to_vector<6>(llvm::map_range(
44 // TOSA allows both 32- and 64-bit integer tensors here.
45 permAttr.getValues<APInt>(),
46 [](const APInt &val) { return val.getZExtValue(); }));
47
48 auto inputType = op.getInput1().getType().cast<ShapedType>();
49 ArrayRef<int64_t> inputShape = inputType.getShape();
50 int64_t numElements = inputType.getNumElements();
51
52 SmallVector<Attribute, 4> outputValues;
53 outputValues.resize(numElements);
54
55 // Transpose the input constant. Because we don't know its rank in advance,
56 // we need to loop over the range [0, element count) and delinearize the
57 // index.
58 auto attrValues = inputValues.getValues<Attribute>();
59 ArrayRef<int64_t> outputShape = outputType.getShape();
60 for (int srcLinearIndex = 0; srcLinearIndex < numElements;
61 ++srcLinearIndex) {
62 SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
63 int totalCount = srcLinearIndex;
64 for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
65 srcIndices[dim] = totalCount % inputShape[dim];
66 totalCount /= inputShape[dim];
67 }
68
69 SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
70 for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
71 dstIndices[dim] = srcIndices[permValues[dim]];
72
73 uint64_t dstLinearIndex = dstIndices.front();
74 for (int dim = 1; dim < outputType.getRank(); ++dim)
75 dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
76
77 outputValues[dstLinearIndex] = attrValues[srcIndices];
78 }
79
80 rewriter.replaceOpWithNewOp<tosa::ConstOp>(
81 op, outputType, DenseElementsAttr::get(outputType, outputValues));
82 return success();
83 }
84 };
85
86 } // namespace
87
populateTosaFoldConstantTransposePatterns(MLIRContext * ctx,RewritePatternSet & patterns)88 void mlir::tosa::populateTosaFoldConstantTransposePatterns(
89 MLIRContext *ctx, RewritePatternSet &patterns) {
90 patterns.add<TosaFoldConstantTranspose>(ctx);
91 }
92