1 //===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Decompose the TOSA depthwise convolution operation into a series of TOSA ops,
// specifically:
// (1) Convert a 1x1 depthwise convolution into Reshape -> Mul -> Reshape -> Add.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
16 #include "mlir/Pass/Pass.h"
17
18 using namespace mlir;
19 using namespace mlir::tosa;
20
21 namespace {
22
23 struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
DepthwiseConv2DIsMul__anonc4269ec50111::DepthwiseConv2DIsMul24 explicit DepthwiseConv2DIsMul(MLIRContext *context)
25 : OpRewritePattern(context) {}
26
matchAndRewrite__anonc4269ec50111::DepthwiseConv2DIsMul27 LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
28 PatternRewriter &rewriter) const override {
29 Value input = op.getInput();
30 Value weight = op.getWeight();
31 ShapedType inputType = input.getType().cast<ShapedType>();
32 ShapedType weightType = weight.getType().cast<ShapedType>();
33 ShapedType resultType = op.getOutput().getType().cast<ShapedType>();
34 Type inputEType = inputType.getElementType();
35
36 if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
37 resultType.hasStaticShape())) {
38 return failure();
39 }
40
41 // Quantization information needs to still be performed.
42 if (op.getQuantizationInfo() || !inputEType.isa<FloatType>()) {
43 return failure();
44 }
45
46 // Stride must be 1 for this optimization.
47 for (Attribute stride : op.getStride().getValue()) {
48 if (!stride.cast<IntegerAttr>().getValue().isOne()) {
49 return failure();
50 }
51 }
52
53 // Only works for a 1x1 kernel.
54 ArrayRef<int64_t> weightShape = weightType.getShape();
55 if (weightShape[0] != 1 || weightShape[1] != 1) {
56 return failure();
57 }
58
59 // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
60 ArrayRef<int64_t> inputShape = inputType.getShape();
61 llvm::SmallVector<int64_t, 2> revisedInputShape{
62 inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
63 auto revisedInputShapeType = RankedTensorType::get(
64 revisedInputShape,
65 input.getType().dyn_cast<RankedTensorType>().getElementType());
66 auto reshapedInput = rewriter
67 .create<tosa::ReshapeOp>(
68 op.getLoc(), revisedInputShapeType, input,
69 rewriter.getI64ArrayAttr(revisedInputShape))
70 .getResult();
71
72 // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
73 llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
74 weightShape[3]};
75 auto revisedWeightShapeType = RankedTensorType::get(
76 revisedWeightShape,
77 weight.getType().dyn_cast<RankedTensorType>().getElementType());
78 auto reshapedWeight = rewriter
79 .create<tosa::ReshapeOp>(
80 op.getLoc(), revisedWeightShapeType, weight,
81 rewriter.getI64ArrayAttr(revisedWeightShape))
82 .getResult();
83
84 // Perform an elementwise mul over the reshaped input and weight.
85 llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
86 inputShape[2], inputShape[3],
87 weightShape[3]};
88 auto mulShapeType = RankedTensorType::get(
89 mulShape,
90 weight.getType().dyn_cast<RankedTensorType>().getElementType());
91 Value mulValue =
92 rewriter
93 .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
94 reshapedWeight, /*shift=*/0)
95 .getResult();
96
97 // Reshape output to [N, H, W, C * M].
98 auto outputShape = op.getOutput().getType().cast<ShapedType>().getShape();
99 auto outputShapeType = RankedTensorType::get(
100 outputShape,
101 input.getType().dyn_cast<RankedTensorType>().getElementType());
102 auto outputValue =
103 rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
104 rewriter.getI64ArrayAttr(outputShape));
105
106 // Add in the bias.
107 rewriter
108 .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
109 op.getBias())
110 .getResult();
111 return success();
112 }
113 };
114
115 } // namespace
116
populateTosaDecomposeDepthwise(MLIRContext * ctx,RewritePatternSet & patterns)117 void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
118 RewritePatternSet &patterns) {
119 patterns.add<DepthwiseConv2DIsMul>(ctx);
120 }
121