1 //===- TosaDecomposeConv2D.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Decompose TOSA Conv2D operation to a series of TOSA Ops specifically
10 // (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
16
17 using namespace mlir;
18 using namespace mlir::tosa;
19
20 namespace {
21
22 struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
Conv2DIsFullyConnected__anon20f5aa440111::Conv2DIsFullyConnected23 explicit Conv2DIsFullyConnected(MLIRContext *context)
24 : OpRewritePattern(context) {}
25
matchAndRewrite__anon20f5aa440111::Conv2DIsFullyConnected26 LogicalResult matchAndRewrite(tosa::Conv2DOp op,
27 PatternRewriter &rewriter) const override {
28 Value input = op.getInput();
29 Value weight = op.getWeight();
30 ShapedType inputType = input.getType().cast<ShapedType>();
31 ShapedType weightType = weight.getType().cast<ShapedType>();
32 ShapedType resultType = op.getType().cast<ShapedType>();
33
34 auto numDynamic =
35 llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
36 if (numDynamic > 1)
37 return rewriter.notifyMatchFailure(
38 op, "at most one dim in input may be dynamic");
39 if (!weightType.hasRank())
40 return rewriter.notifyMatchFailure(op, "unranked weight input");
41
42 // Stride must be 1 for this optimization.
43 for (APInt stride : op.getStride().getAsValueRange<IntegerAttr>()) {
44 if (!stride.isOne())
45 return failure();
46 }
47
48 // Only works for a 1x1 kernel.
49 ArrayRef<int64_t> weightShape = weightType.getShape();
50 if (weightShape[1] != 1 || weightShape[2] != 1)
51 return failure();
52
53 // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
54 ArrayRef<int64_t> inputShape = inputType.getShape();
55 int64_t combined = inputShape[0] * inputShape[1] * inputShape[2];
56 if (combined < 0)
57 combined = ShapedType::kDynamicSize;
58 llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
59 auto revisedInputShapeType =
60 RankedTensorType::get(revisedInputShape, inputType.getElementType());
61 auto reshapedInput = rewriter
62 .create<tosa::ReshapeOp>(
63 op.getLoc(), revisedInputShapeType, input,
64 rewriter.getI64ArrayAttr(revisedInputShape))
65 .getResult();
66
67 // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
68 llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
69 weightShape[3]};
70 auto revisedWeightShapeType = RankedTensorType::get(
71 revisedWeightShape,
72 weight.getType().dyn_cast<RankedTensorType>().getElementType());
73 auto reshapedWeight = rewriter
74 .create<tosa::ReshapeOp>(
75 op.getLoc(), revisedWeightShapeType, weight,
76 rewriter.getI64ArrayAttr(revisedWeightShape))
77 .getResult();
78
79 // Perform a fully connected network over the reshaped input and weight.
80 llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
81 auto fullyConnectedShapeType =
82 RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
83
84 Value fullyConnectedValue;
85 if (op.getQuantizationInfo()) {
86 fullyConnectedValue =
87 rewriter
88 .create<tosa::FullyConnectedOp>(
89 op.getLoc(), fullyConnectedShapeType, reshapedInput,
90 reshapedWeight, op.getBias(), *op.getQuantizationInfo())
91 .getResult();
92 } else {
93 fullyConnectedValue = rewriter
94 .create<tosa::FullyConnectedOp>(
95 op.getLoc(), fullyConnectedShapeType,
96 reshapedInput, reshapedWeight, op.getBias())
97 .getResult();
98 }
99
100 // Reshape output to [N, IH, IW, OC].
101 llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
102 inputShape[2], weightShape[0]};
103 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
104 op, resultType, fullyConnectedValue,
105 rewriter.getI64ArrayAttr(outputShape));
106 return success();
107 }
108 };
109
110 } // namespace
111
populateTosaDecomposeConv2D(MLIRContext * ctx,RewritePatternSet & patterns)112 void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
113 RewritePatternSet &patterns) {
114 patterns.add<Conv2DIsFullyConnected>(ctx);
115 }
116