1 //===- TosaFoldConstantTranspose.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Fold TOSA Transpose operation on constant data
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
14 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/Pass/Pass.h"
17 
18 using namespace mlir;
19 using namespace mlir::tosa;
20 
21 namespace {
22 
23 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
24   using OpRewritePattern::OpRewritePattern;
25 
matchAndRewrite__anon4c0f76470111::TosaFoldConstantTranspose26   LogicalResult matchAndRewrite(tosa::TransposeOp op,
27                                 PatternRewriter &rewriter) const override {
28     auto outputType = op.getType().cast<ShapedType>();
29     // TOSA supports quantized types.
30     if (!outputType.getElementType().isIntOrIndexOrFloat())
31       return failure();
32 
33     DenseElementsAttr inputValues;
34     if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
35       return failure();
36     // Make sure the input is a constant that has a single user.
37     if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
38       return failure();
39 
40     DenseIntElementsAttr permAttr;
41     if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
42       return failure();
43     auto permValues = llvm::to_vector<6>(llvm::map_range(
44         // TOSA allows both 32- and 64-bit integer tensors here.
45         permAttr.getValues<APInt>(),
46         [](const APInt &val) { return val.getZExtValue(); }));
47 
48     auto inputType = op.getInput1().getType().cast<ShapedType>();
49     ArrayRef<int64_t> inputShape = inputType.getShape();
50     int64_t numElements = inputType.getNumElements();
51 
52     SmallVector<Attribute, 4> outputValues;
53     outputValues.resize(numElements);
54 
55     // Transpose the input constant. Because we don't know its rank in advance,
56     // we need to loop over the range [0, element count) and delinearize the
57     // index.
58     auto attrValues = inputValues.getValues<Attribute>();
59     ArrayRef<int64_t> outputShape = outputType.getShape();
60     for (int srcLinearIndex = 0; srcLinearIndex < numElements;
61          ++srcLinearIndex) {
62       SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
63       int totalCount = srcLinearIndex;
64       for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
65         srcIndices[dim] = totalCount % inputShape[dim];
66         totalCount /= inputShape[dim];
67       }
68 
69       SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
70       for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
71         dstIndices[dim] = srcIndices[permValues[dim]];
72 
73       uint64_t dstLinearIndex = dstIndices.front();
74       for (int dim = 1; dim < outputType.getRank(); ++dim)
75         dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
76 
77       outputValues[dstLinearIndex] = attrValues[srcIndices];
78     }
79 
80     rewriter.replaceOpWithNewOp<tosa::ConstOp>(
81         op, outputType, DenseElementsAttr::get(outputType, outputValues));
82     return success();
83   }
84 };
85 
86 } // namespace
87 
/// Adds the TosaFoldConstantTranspose rewrite pattern, constructed with `ctx`,
/// to `patterns`.
void mlir::tosa::populateTosaFoldConstantTransposePatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantTranspose>(ctx);
}
92