//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Decompose TOSA Depthwise operation to a series of TOSA Ops specifically
11 // (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
17 #include "mlir/Pass/Pass.h"
18 
19 using namespace mlir;
20 using namespace mlir::tosa;
21 
22 namespace {
23 
24 struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
25   explicit DepthwiseConv2DIsMul(MLIRContext *context)
26       : OpRewritePattern(context) {}
27 
28   LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
29                                 PatternRewriter &rewriter) const override {
30     Value input = op.input();
31     Value weight = op.weight();
32     ShapedType inputType = input.getType().cast<ShapedType>();
33     ShapedType weightType = weight.getType().cast<ShapedType>();
34     ShapedType resultType = op.output().getType().cast<ShapedType>();
35     Type inputEType = inputType.getElementType();
36 
37     if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
38           resultType.hasStaticShape())) {
39       return failure();
40     }
41 
42     // Quantization information needs to still be performed.
43     if (op.quantization_info() || !inputEType.isa<FloatType>()) {
44       return failure();
45     }
46 
47     // Stride must be 1 for this optimization.
48     for (Attribute stride : op.stride().getValue()) {
49       if (!stride.cast<IntegerAttr>().getValue().isOne()) {
50         return failure();
51       }
52     }
53 
54     // Only works for a 1x1 kernel.
55     ArrayRef<int64_t> weightShape = weightType.getShape();
56     if (weightShape[0] != 1 || weightShape[1] != 1) {
57       return failure();
58     }
59 
60     // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
61     ArrayRef<int64_t> inputShape = inputType.getShape();
62     llvm::SmallVector<int64_t, 2> revisedInputShape{
63         inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
64     auto revisedInputShapeType = RankedTensorType::get(
65         revisedInputShape,
66         input.getType().dyn_cast<RankedTensorType>().getElementType());
67     auto reshapedInput = rewriter
68                              .create<tosa::ReshapeOp>(
69                                  op.getLoc(), revisedInputShapeType, input,
70                                  rewriter.getI64ArrayAttr(revisedInputShape))
71                              .getResult();
72 
73     // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
74     llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
75                                                      weightShape[3]};
76     auto revisedWeightShapeType = RankedTensorType::get(
77         revisedWeightShape,
78         weight.getType().dyn_cast<RankedTensorType>().getElementType());
79     auto reshapedWeight = rewriter
80                               .create<tosa::ReshapeOp>(
81                                   op.getLoc(), revisedWeightShapeType, weight,
82                                   rewriter.getI64ArrayAttr(revisedWeightShape))
83                               .getResult();
84 
85     // Perform an elementwise mul over the reshaped input and weight.
86     llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
87                                            inputShape[2], inputShape[3],
88                                            weightShape[3]};
89     auto mulShapeType = RankedTensorType::get(
90         mulShape,
91         weight.getType().dyn_cast<RankedTensorType>().getElementType());
92     Value mulValue =
93         rewriter
94             .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
95                                  reshapedWeight, /*shift=*/0)
96             .getResult();
97 
98     // Reshape output to [N, H, W, C * M].
99     auto outputShape = op.output().getType().cast<ShapedType>().getShape();
100     auto outputShapeType = RankedTensorType::get(
101         outputShape,
102         input.getType().dyn_cast<RankedTensorType>().getElementType());
103     auto outputValue =
104         rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
105                                          rewriter.getI64ArrayAttr(outputShape));
106 
107     // Add in the bias.
108     rewriter
109         .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
110                                          op.bias())
111         .getResult();
112     return success();
113   }
114 };
115 
116 } // namespace
117 
118 void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
119                                                 RewritePatternSet &patterns) {
120   patterns.insert<DepthwiseConv2DIsMul>(ctx);
121 }
122