1 //===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Decompose TOSA Depthwise operation to a series of TOSA Ops specifically
10 // (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
16 #include "mlir/Pass/Pass.h"
17 
18 using namespace mlir;
19 using namespace mlir::tosa;
20 
21 namespace {
22 
23 struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
DepthwiseConv2DIsMul__anonc4269ec50111::DepthwiseConv2DIsMul24   explicit DepthwiseConv2DIsMul(MLIRContext *context)
25       : OpRewritePattern(context) {}
26 
matchAndRewrite__anonc4269ec50111::DepthwiseConv2DIsMul27   LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
28                                 PatternRewriter &rewriter) const override {
29     Value input = op.getInput();
30     Value weight = op.getWeight();
31     ShapedType inputType = input.getType().cast<ShapedType>();
32     ShapedType weightType = weight.getType().cast<ShapedType>();
33     ShapedType resultType = op.getOutput().getType().cast<ShapedType>();
34     Type inputEType = inputType.getElementType();
35 
36     if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
37           resultType.hasStaticShape())) {
38       return failure();
39     }
40 
41     // Quantization information needs to still be performed.
42     if (op.getQuantizationInfo() || !inputEType.isa<FloatType>()) {
43       return failure();
44     }
45 
46     // Stride must be 1 for this optimization.
47     for (Attribute stride : op.getStride().getValue()) {
48       if (!stride.cast<IntegerAttr>().getValue().isOne()) {
49         return failure();
50       }
51     }
52 
53     // Only works for a 1x1 kernel.
54     ArrayRef<int64_t> weightShape = weightType.getShape();
55     if (weightShape[0] != 1 || weightShape[1] != 1) {
56       return failure();
57     }
58 
59     // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
60     ArrayRef<int64_t> inputShape = inputType.getShape();
61     llvm::SmallVector<int64_t, 2> revisedInputShape{
62         inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
63     auto revisedInputShapeType = RankedTensorType::get(
64         revisedInputShape,
65         input.getType().dyn_cast<RankedTensorType>().getElementType());
66     auto reshapedInput = rewriter
67                              .create<tosa::ReshapeOp>(
68                                  op.getLoc(), revisedInputShapeType, input,
69                                  rewriter.getI64ArrayAttr(revisedInputShape))
70                              .getResult();
71 
72     // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
73     llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
74                                                      weightShape[3]};
75     auto revisedWeightShapeType = RankedTensorType::get(
76         revisedWeightShape,
77         weight.getType().dyn_cast<RankedTensorType>().getElementType());
78     auto reshapedWeight = rewriter
79                               .create<tosa::ReshapeOp>(
80                                   op.getLoc(), revisedWeightShapeType, weight,
81                                   rewriter.getI64ArrayAttr(revisedWeightShape))
82                               .getResult();
83 
84     // Perform an elementwise mul over the reshaped input and weight.
85     llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
86                                            inputShape[2], inputShape[3],
87                                            weightShape[3]};
88     auto mulShapeType = RankedTensorType::get(
89         mulShape,
90         weight.getType().dyn_cast<RankedTensorType>().getElementType());
91     Value mulValue =
92         rewriter
93             .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
94                                  reshapedWeight, /*shift=*/0)
95             .getResult();
96 
97     // Reshape output to [N, H, W, C * M].
98     auto outputShape = op.getOutput().getType().cast<ShapedType>().getShape();
99     auto outputShapeType = RankedTensorType::get(
100         outputShape,
101         input.getType().dyn_cast<RankedTensorType>().getElementType());
102     auto outputValue =
103         rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
104                                          rewriter.getI64ArrayAttr(outputShape));
105 
106     // Add in the bias.
107     rewriter
108         .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
109                                          op.getBias())
110         .getResult();
111     return success();
112   }
113 };
114 
115 } // namespace
116 
populateTosaDecomposeDepthwise(MLIRContext * ctx,RewritePatternSet & patterns)117 void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
118                                                 RewritePatternSet &patterns) {
119   patterns.add<DepthwiseConv2DIsMul>(ctx);
120 }
121