1 //===- TosaDecomposeConv2D.cpp --------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Decompose TOSA Conv2D operation to a series of TOSA Ops specifically 10 // (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 15 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 16 17 using namespace mlir; 18 using namespace mlir::tosa; 19 20 namespace { 21 22 struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> { 23 explicit Conv2DIsFullyConnected(MLIRContext *context) 24 : OpRewritePattern(context) {} 25 26 LogicalResult matchAndRewrite(tosa::Conv2DOp op, 27 PatternRewriter &rewriter) const override { 28 Value input = op.getInput(); 29 Value weight = op.getWeight(); 30 ShapedType inputType = input.getType().cast<ShapedType>(); 31 ShapedType weightType = weight.getType().cast<ShapedType>(); 32 ShapedType resultType = op.getType().cast<ShapedType>(); 33 34 auto numDynamic = 35 llvm::count_if(inputType.getShape(), ShapedType::isDynamic); 36 if (numDynamic > 1) 37 return rewriter.notifyMatchFailure( 38 op, "at most one dim in input may be dynamic"); 39 if (!weightType.hasRank()) 40 return rewriter.notifyMatchFailure(op, "unranked weight input"); 41 42 // Stride must be 1 for this optimization. 43 for (APInt stride : op.getStride().getAsValueRange<IntegerAttr>()) { 44 if (!stride.isOne()) 45 return failure(); 46 } 47 48 // Only works for a 1x1 kernel. 49 ArrayRef<int64_t> weightShape = weightType.getShape(); 50 if (weightShape[1] != 1 || weightShape[2] != 1) 51 return failure(); 52 53 // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC]. 54 ArrayRef<int64_t> inputShape = inputType.getShape(); 55 int64_t combined = inputShape[0] * inputShape[1] * inputShape[2]; 56 if (combined < 0) 57 combined = ShapedType::kDynamicSize; 58 llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]}; 59 auto revisedInputShapeType = 60 RankedTensorType::get(revisedInputShape, inputType.getElementType()); 61 auto reshapedInput = rewriter 62 .create<tosa::ReshapeOp>( 63 op.getLoc(), revisedInputShapeType, input, 64 rewriter.getI64ArrayAttr(revisedInputShape)) 65 .getResult(); 66 67 // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC]. 68 llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0], 69 weightShape[3]}; 70 auto revisedWeightShapeType = RankedTensorType::get( 71 revisedWeightShape, 72 weight.getType().dyn_cast<RankedTensorType>().getElementType()); 73 auto reshapedWeight = rewriter 74 .create<tosa::ReshapeOp>( 75 op.getLoc(), revisedWeightShapeType, weight, 76 rewriter.getI64ArrayAttr(revisedWeightShape)) 77 .getResult(); 78 79 // Perform a fully connected network over the reshaped input and weight. 80 llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]}; 81 auto fullyConnectedShapeType = 82 RankedTensorType::get(fullyConnectedShape, resultType.getElementType()); 83 84 Value fullyConnectedValue; 85 if (op.getQuantizationInfo()) { 86 fullyConnectedValue = 87 rewriter 88 .create<tosa::FullyConnectedOp>( 89 op.getLoc(), fullyConnectedShapeType, reshapedInput, 90 reshapedWeight, op.getBias(), *op.getQuantizationInfo()) 91 .getResult(); 92 } else { 93 fullyConnectedValue = rewriter 94 .create<tosa::FullyConnectedOp>( 95 op.getLoc(), fullyConnectedShapeType, 96 reshapedInput, reshapedWeight, op.getBias()) 97 .getResult(); 98 } 99 100 // Reshape output to [N, IH, IW, OC]. 101 llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1], 102 inputShape[2], weightShape[0]}; 103 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( 104 op, resultType, fullyConnectedValue, 105 rewriter.getI64ArrayAttr(outputShape)); 106 return success(); 107 } 108 }; 109 110 } // namespace 111 112 void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx, 113 RewritePatternSet &patterns) { 114 patterns.add<Conv2DIsFullyConnected>(ctx); 115 } 116