1 //===- TosaDecomposeConv2D.cpp --------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Decompose TOSA Conv2D operation to a series of TOSA Ops specifically 10 // (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 15 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 16 17 using namespace mlir; 18 using namespace mlir::tosa; 19 20 namespace { 21 22 struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> { 23 explicit Conv2DIsFullyConnected(MLIRContext *context) 24 : OpRewritePattern(context) {} 25 26 LogicalResult matchAndRewrite(tosa::Conv2DOp op, 27 PatternRewriter &rewriter) const override { 28 Value input = op.input(); 29 Value weight = op.weight(); 30 ShapedType inputType = input.getType().cast<ShapedType>(); 31 ShapedType weightType = weight.getType().cast<ShapedType>(); 32 ShapedType resultType = op.getType().cast<ShapedType>(); 33 34 auto numDynamic = llvm::count_if(inputType.getShape(), [](int64_t d) { 35 return ShapedType::isDynamic(d); 36 }); 37 if (numDynamic > 1) 38 return rewriter.notifyMatchFailure( 39 op, "at most one dim in input may be dynamic"); 40 if (!weightType.hasRank()) 41 return rewriter.notifyMatchFailure(op, "unranked weight input"); 42 43 // Stride must be 1 for this optimization. 44 for (APInt stride : op.stride().getAsValueRange<IntegerAttr>()) { 45 if (!stride.isOne()) 46 return failure(); 47 } 48 49 // Only works for a 1x1 kernel. 50 ArrayRef<int64_t> weightShape = weightType.getShape(); 51 if (weightShape[1] != 1 || weightShape[2] != 1) 52 return failure(); 53 54 // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC]. 55 ArrayRef<int64_t> inputShape = inputType.getShape(); 56 int64_t combined = inputShape[0] * inputShape[1] * inputShape[2]; 57 if (combined < 0) 58 combined = ShapedType::kDynamicSize; 59 llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]}; 60 auto revisedInputShapeType = 61 RankedTensorType::get(revisedInputShape, inputType.getElementType()); 62 auto reshapedInput = rewriter 63 .create<tosa::ReshapeOp>( 64 op.getLoc(), revisedInputShapeType, input, 65 rewriter.getI64ArrayAttr(revisedInputShape)) 66 .getResult(); 67 68 // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC]. 69 llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0], 70 weightShape[3]}; 71 auto revisedWeightShapeType = RankedTensorType::get( 72 revisedWeightShape, 73 weight.getType().dyn_cast<RankedTensorType>().getElementType()); 74 auto reshapedWeight = rewriter 75 .create<tosa::ReshapeOp>( 76 op.getLoc(), revisedWeightShapeType, weight, 77 rewriter.getI64ArrayAttr(revisedWeightShape)) 78 .getResult(); 79 80 // Perform a fully connected network over the reshaped input and weight. 81 llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]}; 82 auto fullyConnectedShapeType = 83 RankedTensorType::get(fullyConnectedShape, resultType.getElementType()); 84 85 Value fullyConnectedValue; 86 if (op.quantization_info()) { 87 fullyConnectedValue = 88 rewriter 89 .create<tosa::FullyConnectedOp>( 90 op.getLoc(), fullyConnectedShapeType, reshapedInput, 91 reshapedWeight, op.bias(), *op.quantization_info()) 92 .getResult(); 93 } else { 94 fullyConnectedValue = rewriter 95 .create<tosa::FullyConnectedOp>( 96 op.getLoc(), fullyConnectedShapeType, 97 reshapedInput, reshapedWeight, op.bias()) 98 .getResult(); 99 } 100 101 // Reshape output to [N, IH, IW, OC]. 102 llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1], 103 inputShape[2], weightShape[0]}; 104 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( 105 op, resultType, fullyConnectedValue, 106 rewriter.getI64ArrayAttr(outputShape)); 107 return success(); 108 } 109 }; 110 111 } // namespace 112 113 void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx, 114 RewritePatternSet &patterns) { 115 patterns.add<Conv2DIsFullyConnected>(ctx); 116 } 117