1 //===- TosaFoldConstantTranspose.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Fold TOSA Transpose operation on constant data 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 14 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/Pass/Pass.h" 17 18 using namespace mlir; 19 using namespace mlir::tosa; 20 21 namespace { 22 23 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> { 24 using OpRewritePattern::OpRewritePattern; 25 26 LogicalResult matchAndRewrite(tosa::TransposeOp op, 27 PatternRewriter &rewriter) const override { 28 auto outputType = op.getType().cast<ShapedType>(); 29 // TOSA supports quantized types. 30 if (!outputType.getElementType().isIntOrIndexOrFloat()) 31 return failure(); 32 33 DenseElementsAttr inputValues; 34 if (!matchPattern(op.input1(), m_Constant(&inputValues))) 35 return failure(); 36 // Make sure the input is a constant that has a single user. 37 if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers())) 38 return failure(); 39 40 DenseIntElementsAttr permAttr; 41 if (!matchPattern(op.perms(), m_Constant(&permAttr))) 42 return failure(); 43 auto permValues = llvm::to_vector<6>(llvm::map_range( 44 // TOSA allows both 32- and 64-bit integer tensors here. 45 permAttr.getValues<APInt>(), 46 [](const APInt &val) { return val.getZExtValue(); })); 47 48 auto inputType = op.input1().getType().cast<ShapedType>(); 49 ArrayRef<int64_t> inputShape = inputType.getShape(); 50 int64_t numElements = inputType.getNumElements(); 51 52 SmallVector<Attribute, 4> outputValues; 53 outputValues.resize(numElements); 54 55 // Transpose the input constant. Because we don't know its rank in advance, 56 // we need to loop over the range [0, element count) and delinearize the 57 // index. 58 auto attrValues = inputValues.getValues<Attribute>(); 59 ArrayRef<int64_t> outputShape = outputType.getShape(); 60 for (int srcLinearIndex = 0; srcLinearIndex < numElements; 61 ++srcLinearIndex) { 62 SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0); 63 int totalCount = srcLinearIndex; 64 for (int dim = inputType.getRank() - 1; dim >= 0; --dim) { 65 srcIndices[dim] = totalCount % inputShape[dim]; 66 totalCount /= inputShape[dim]; 67 } 68 69 SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0); 70 for (int dim = outputType.getRank() - 1; dim >= 0; --dim) 71 dstIndices[dim] = srcIndices[permValues[dim]]; 72 73 uint64_t dstLinearIndex = dstIndices.front(); 74 for (int dim = 1; dim < outputType.getRank(); ++dim) 75 dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; 76 77 outputValues[dstLinearIndex] = attrValues[srcIndices]; 78 } 79 80 rewriter.replaceOpWithNewOp<tosa::ConstOp>( 81 op, outputType, DenseElementsAttr::get(outputType, outputValues)); 82 return success(); 83 } 84 }; 85 86 } // namespace 87 88 void mlir::tosa::populateTosaFoldConstantTransposePatterns( 89 MLIRContext *ctx, RewritePatternSet &patterns) { 90 patterns.add<TosaFoldConstantTranspose>(ctx); 91 } 92