//===- TosaDecomposeConv2D.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose TOSA Conv2D operation to a series of TOSA Ops specifically
// (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape
//
//===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
16 
17 using namespace mlir;
18 using namespace mlir::tosa;
19 
20 namespace {
21 
22 struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
Conv2DIsFullyConnected__anon20f5aa440111::Conv2DIsFullyConnected23   explicit Conv2DIsFullyConnected(MLIRContext *context)
24       : OpRewritePattern(context) {}
25 
matchAndRewrite__anon20f5aa440111::Conv2DIsFullyConnected26   LogicalResult matchAndRewrite(tosa::Conv2DOp op,
27                                 PatternRewriter &rewriter) const override {
28     Value input = op.getInput();
29     Value weight = op.getWeight();
30     ShapedType inputType = input.getType().cast<ShapedType>();
31     ShapedType weightType = weight.getType().cast<ShapedType>();
32     ShapedType resultType = op.getType().cast<ShapedType>();
33 
34     auto numDynamic =
35         llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
36     if (numDynamic > 1)
37       return rewriter.notifyMatchFailure(
38           op, "at most one dim in input may be dynamic");
39     if (!weightType.hasRank())
40       return rewriter.notifyMatchFailure(op, "unranked weight input");
41 
42     // Stride must be 1 for this optimization.
43     for (APInt stride : op.getStride().getAsValueRange<IntegerAttr>()) {
44       if (!stride.isOne())
45         return failure();
46     }
47 
48     // Only works for a 1x1 kernel.
49     ArrayRef<int64_t> weightShape = weightType.getShape();
50     if (weightShape[1] != 1 || weightShape[2] != 1)
51       return failure();
52 
53     // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
54     ArrayRef<int64_t> inputShape = inputType.getShape();
55     int64_t combined = inputShape[0] * inputShape[1] * inputShape[2];
56     if (combined < 0)
57       combined = ShapedType::kDynamicSize;
58     llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
59     auto revisedInputShapeType =
60         RankedTensorType::get(revisedInputShape, inputType.getElementType());
61     auto reshapedInput = rewriter
62                              .create<tosa::ReshapeOp>(
63                                  op.getLoc(), revisedInputShapeType, input,
64                                  rewriter.getI64ArrayAttr(revisedInputShape))
65                              .getResult();
66 
67     // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
68     llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
69                                                      weightShape[3]};
70     auto revisedWeightShapeType = RankedTensorType::get(
71         revisedWeightShape,
72         weight.getType().dyn_cast<RankedTensorType>().getElementType());
73     auto reshapedWeight = rewriter
74                               .create<tosa::ReshapeOp>(
75                                   op.getLoc(), revisedWeightShapeType, weight,
76                                   rewriter.getI64ArrayAttr(revisedWeightShape))
77                               .getResult();
78 
79     // Perform a fully connected network over the reshaped input and weight.
80     llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
81     auto fullyConnectedShapeType =
82         RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
83 
84     Value fullyConnectedValue;
85     if (op.getQuantizationInfo()) {
86       fullyConnectedValue =
87           rewriter
88               .create<tosa::FullyConnectedOp>(
89                   op.getLoc(), fullyConnectedShapeType, reshapedInput,
90                   reshapedWeight, op.getBias(), *op.getQuantizationInfo())
91               .getResult();
92     } else {
93       fullyConnectedValue = rewriter
94                                 .create<tosa::FullyConnectedOp>(
95                                     op.getLoc(), fullyConnectedShapeType,
96                                     reshapedInput, reshapedWeight, op.getBias())
97                                 .getResult();
98     }
99 
100     // Reshape output to [N, IH, IW, OC].
101     llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
102                                               inputShape[2], weightShape[0]};
103     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
104         op, resultType, fullyConnectedValue,
105         rewriter.getI64ArrayAttr(outputShape));
106     return success();
107   }
108 };
109 
110 } // namespace
111 
populateTosaDecomposeConv2D(MLIRContext * ctx,RewritePatternSet & patterns)112 void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
113                                              RewritePatternSet &patterns) {
114   patterns.add<Conv2DIsFullyConnected>(ctx);
115 }
116