1 //===- TosaDecomposeConv2D.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Decompose the TOSA Conv2D operation into a series of TOSA ops, specifically:
// (1) Convert a 1x1 convolution into a Reshape->FC->Reshape sequence.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
16 
17 using namespace mlir;
18 using namespace mlir::tosa;
19 
20 namespace {
21 
22 struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
23   explicit Conv2DIsFullyConnected(MLIRContext *context)
24       : OpRewritePattern(context) {}
25 
26   LogicalResult matchAndRewrite(tosa::Conv2DOp op,
27                                 PatternRewriter &rewriter) const override {
28     Value input = op.input();
29     Value weight = op.weight();
30     ShapedType inputType = input.getType().cast<ShapedType>();
31     ShapedType weightType = weight.getType().cast<ShapedType>();
32     ShapedType resultType = op.getType().cast<ShapedType>();
33 
34     auto numDynamic = llvm::count_if(inputType.getShape(), [](int64_t d) {
35       return ShapedType::isDynamic(d);
36     });
37     if (numDynamic > 1)
38       return rewriter.notifyMatchFailure(
39           op, "at most one dim in input may be dynamic");
40     if (!weightType.hasRank())
41       return rewriter.notifyMatchFailure(op, "unranked weight input");
42 
43     // Stride must be 1 for this optimization.
44     for (APInt stride : op.stride().getAsValueRange<IntegerAttr>()) {
45       if (!stride.isOne())
46         return failure();
47     }
48 
49     // Only works for a 1x1 kernel.
50     ArrayRef<int64_t> weightShape = weightType.getShape();
51     if (weightShape[1] != 1 || weightShape[2] != 1)
52       return failure();
53 
54     // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
55     ArrayRef<int64_t> inputShape = inputType.getShape();
56     int64_t combined = inputShape[0] * inputShape[1] * inputShape[2];
57     if (combined < 0)
58       combined = ShapedType::kDynamicSize;
59     llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
60     auto revisedInputShapeType =
61         RankedTensorType::get(revisedInputShape, inputType.getElementType());
62     auto reshapedInput = rewriter
63                              .create<tosa::ReshapeOp>(
64                                  op.getLoc(), revisedInputShapeType, input,
65                                  rewriter.getI64ArrayAttr(revisedInputShape))
66                              .getResult();
67 
68     // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
69     llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
70                                                      weightShape[3]};
71     auto revisedWeightShapeType = RankedTensorType::get(
72         revisedWeightShape,
73         weight.getType().dyn_cast<RankedTensorType>().getElementType());
74     auto reshapedWeight = rewriter
75                               .create<tosa::ReshapeOp>(
76                                   op.getLoc(), revisedWeightShapeType, weight,
77                                   rewriter.getI64ArrayAttr(revisedWeightShape))
78                               .getResult();
79 
80     // Perform a fully connected network over the reshaped input and weight.
81     llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
82     auto fullyConnectedShapeType =
83         RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
84 
85     Value fullyConnectedValue;
86     if (op.quantization_info()) {
87       fullyConnectedValue =
88           rewriter
89               .create<tosa::FullyConnectedOp>(
90                   op.getLoc(), fullyConnectedShapeType, reshapedInput,
91                   reshapedWeight, op.bias(), *op.quantization_info())
92               .getResult();
93     } else {
94       fullyConnectedValue = rewriter
95                                 .create<tosa::FullyConnectedOp>(
96                                     op.getLoc(), fullyConnectedShapeType,
97                                     reshapedInput, reshapedWeight, op.bias())
98                                 .getResult();
99     }
100 
101     // Reshape output to [N, IH, IW, OC].
102     llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
103                                               inputShape[2], weightShape[0]};
104     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
105         op, resultType, fullyConnectedValue,
106         rewriter.getI64ArrayAttr(outputShape));
107     return success();
108   }
109 };
110 
111 } // namespace
112 
113 void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
114                                              RewritePatternSet &patterns) {
115   patterns.add<Conv2DIsFullyConnected>(ctx);
116 }
117