//===- TosaDecomposeConv2D.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA Conv2D operation into a series of TOSA ops, specifically:
// (1) Convert a 1x1 convolution to a Reshape -> FullyConnected -> Reshape
//     sequence.
//
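// As a rough sketch (shapes and attribute syntax are illustrative, not taken
// from a specific test), a 1x1 convolution such as
//
//   %0 = "tosa.conv2d"(%input, %weight, %bias)
//            {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]}
//            : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>)
//            -> tensor<4x10x10x3xf32>
//
// is rewritten to
//
//   %0 = "tosa.reshape"(%input) {new_shape = [400, 2]}
//   %1 = "tosa.reshape"(%weight) {new_shape = [3, 2]}
//   %2 = "tosa.fully_connected"(%0, %1, %bias) : (...) -> tensor<400x3xf32>
//   %3 = "tosa.reshape"(%2) {new_shape = [4, 10, 10, 3]}
//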
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

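/// Rewrites a stride-1, unpadded 1x1 tosa.conv2d as
/// tosa.reshape -> tosa.fully_connected -> tosa.reshape.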
struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
  explicit Conv2DIsFullyConnected(MLIRContext *context)
      : OpRewritePattern(context) {}

  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.input();
    Value weight = op.weight();
    ShapedType inputType = input.getType().cast<ShapedType>();
    ShapedType weightType = weight.getType().cast<ShapedType>();
    ShapedType resultType = op.getType().cast<ShapedType>();

    // The shape arithmetic below needs a static input shape and a ranked
    // weight type.
    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
      return failure();
    }

    // Stride must be 1 for this optimization.
    for (Attribute stride : op.stride().getValue()) {
      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
        return failure();
      }
    }

    // Conservatively bail out on non-zero padding; folding the convolution
    // into a fully connected op would drop the padded border.
    for (Attribute pad : op.pad().getValue()) {
      if (!pad.cast<IntegerAttr>().getValue().isZero()) {
        return failure();
      }
    }

    // Only works for a 1x1 kernel.
    ArrayRef<int64_t> weightShape = weightType.getShape();
    if (weightShape[1] != 1 || weightShape[2] != 1) {
      return failure();
    }

    // Reshape input from [N, IH, IW, IC] to [N * IH * IW, IC].
    ArrayRef<int64_t> inputShape = inputType.getShape();
    llvm::SmallVector<int64_t, 2> revisedInputShape{
        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
    auto revisedInputShapeType =
        RankedTensorType::get(revisedInputShape, inputType.getElementType());
    auto reshapedInput = rewriter
                             .create<tosa::ReshapeOp>(
                                 op.getLoc(), revisedInputShapeType, input,
                                 rewriter.getI64ArrayAttr(revisedInputShape))
                             .getResult();

    // Reshape kernel from [OC, KH, KW, IC] to [OC, IC].
    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
                                                     weightShape[3]};
    auto revisedWeightShapeType =
        RankedTensorType::get(revisedWeightShape, weightType.getElementType());
    auto reshapedWeight = rewriter
                              .create<tosa::ReshapeOp>(
                                  op.getLoc(), revisedWeightShapeType, weight,
                                  rewriter.getI64ArrayAttr(revisedWeightShape))
                              .getResult();

    // Perform a fully connected operation over the reshaped input and weight.
    llvm::SmallVector<int64_t, 2> fullyConnectedShape{
        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
    auto fullyConnectedShapeType =
        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());

    Value fullyConnectedValue;
    if (op.quantization_info()) {
      fullyConnectedValue =
          rewriter
              .create<tosa::FullyConnectedOp>(
                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
                  reshapedWeight, op.bias(), op.quantization_info().getValue())
              .getResult();
    } else {
      fullyConnectedValue = rewriter
                                .create<tosa::FullyConnectedOp>(
                                    op.getLoc(), fullyConnectedShapeType,
                                    reshapedInput, reshapedWeight, op.bias())
                                .getResult();
    }

    // Reshape output to [N, IH, IW, OC].
    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
                                              inputShape[2], weightShape[0]};
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultType, fullyConnectedValue,
        rewriter.getI64ArrayAttr(outputShape));
    return success();
  }
};

} // namespace

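/// Populates `patterns` with the Conv2D decomposition pattern defined above.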
void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
                                             RewritePatternSet &patterns) {
  patterns.insert<Conv2DIsFullyConnected>(ctx);
}