1 //===- TosaDecomposeConv2D.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Decompose TOSA Conv2D operation to a series of TOSA Ops specifically 10 // (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 15 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 16 #include "mlir/Pass/Pass.h" 17 18 using namespace mlir; 19 using namespace mlir::tosa; 20 21 namespace { 22 23 struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> { 24 explicit Conv2DIsFullyConnected(MLIRContext *context) 25 : OpRewritePattern(context) {} 26 27 LogicalResult matchAndRewrite(tosa::Conv2DOp op, 28 PatternRewriter &rewriter) const override { 29 Value input = op.input(); 30 Value weight = op.weight(); 31 ShapedType inputType = input.getType().cast<ShapedType>(); 32 ShapedType weightType = weight.getType().cast<ShapedType>(); 33 ShapedType resultType = op.getType().cast<ShapedType>(); 34 35 if (!inputType.hasStaticShape() || !weightType.hasRank()) { 36 return failure(); 37 } 38 39 // Stride must be 1 for this optimization. 40 for (Attribute stride : op.stride().getValue()) { 41 if (!stride.cast<IntegerAttr>().getValue().isOne()) { 42 return failure(); 43 } 44 } 45 46 // Only works for a 1x1 kernel. 47 ArrayRef<int64_t> weightShape = weightType.getShape(); 48 if (weightShape[1] != 1 || weightShape[2] != 1) { 49 return failure(); 50 } 51 52 // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC]. 53 ArrayRef<int64_t> inputShape = inputType.getShape(); 54 llvm::SmallVector<int64_t, 2> revisedInputShape{ 55 inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]}; 56 auto revisedInputShapeType = RankedTensorType::get( 57 revisedInputShape, 58 input.getType().dyn_cast<RankedTensorType>().getElementType()); 59 auto reshapedInput = rewriter 60 .create<tosa::ReshapeOp>( 61 op.getLoc(), revisedInputShapeType, input, 62 rewriter.getI64ArrayAttr(revisedInputShape)) 63 .getResult(); 64 65 // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC]. 66 llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0], 67 weightShape[3]}; 68 auto revisedWeightShapeType = RankedTensorType::get( 69 revisedWeightShape, 70 weight.getType().dyn_cast<RankedTensorType>().getElementType()); 71 auto reshapedWeight = rewriter 72 .create<tosa::ReshapeOp>( 73 op.getLoc(), revisedWeightShapeType, weight, 74 rewriter.getI64ArrayAttr(revisedWeightShape)) 75 .getResult(); 76 77 // Perform a fully connected network over the reshaped input and weight. 78 llvm::SmallVector<int64_t, 2> fullyConnectedShape{ 79 inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]}; 80 auto fullyConnectedShapeType = RankedTensorType::get( 81 fullyConnectedShape, 82 resultType.dyn_cast<ShapedType>().getElementType()); 83 84 Value fullyConnectedValue; 85 if (op.quantization_info()) { 86 fullyConnectedValue = 87 rewriter 88 .create<tosa::FullyConnectedOp>( 89 op.getLoc(), fullyConnectedShapeType, reshapedInput, 90 reshapedWeight, op.bias(), op.quantization_info().getValue()) 91 .getResult(); 92 } else { 93 fullyConnectedValue = rewriter 94 .create<tosa::FullyConnectedOp>( 95 op.getLoc(), fullyConnectedShapeType, 96 reshapedInput, reshapedWeight, op.bias()) 97 .getResult(); 98 } 99 100 // Reshape output to [N, IH, IW, OC]. 101 llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1], 102 inputShape[2], weightShape[0]}; 103 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( 104 op, resultType, fullyConnectedValue, 105 rewriter.getI64ArrayAttr(outputShape)); 106 return success(); 107 } 108 }; 109 110 } // namespace 111 112 void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx, 113 RewritePatternSet &patterns) { 114 patterns.insert<Conv2DIsFullyConnected>(ctx); 115 } 116