//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA depthwise convolution operation into a series of simpler
// TOSA ops. Specifically:
// (1) Convert a 1x1 DepthwiseConv2D to Reshape -> Mul -> Reshape -> Add.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
  explicit DepthwiseConv2DIsMul(MLIRContext *context)
      : OpRewritePattern(context) {}

  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.input();
    Value weight = op.weight();
    ShapedType inputType = input.getType().cast<ShapedType>();
    ShapedType weightType = weight.getType().cast<ShapedType>();
    ShapedType resultType = op.output().getType().cast<ShapedType>();
    Type inputEType = inputType.getElementType();

    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
          resultType.hasStaticShape())) {
      return failure();
    }

    // Quantized and non-float depthwise convolutions are not handled by this
    // rewrite; bail out and leave them for later lowering.
    if (op.quantization_info() || !inputEType.isa<FloatType>()) {
      return failure();
    }

    // Stride must be 1 for this optimization.
    for (Attribute stride : op.stride().getValue()) {
      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
        return failure();
      }
    }

    // Only works for a 1x1 kernel.
    ArrayRef<int64_t> weightShape = weightType.getShape();
    if (weightShape[0] != 1 || weightShape[1] != 1) {
      return failure();
    }

    // Reshape input from [N, H, W, C] to [N, H, W, C, 1].
    ArrayRef<int64_t> inputShape = inputType.getShape();
    llvm::SmallVector<int64_t, 2> revisedInputShape{
        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
    auto revisedInputShapeType = RankedTensorType::get(
        revisedInputShape,
        input.getType().dyn_cast<RankedTensorType>().getElementType());
    auto reshapedInput = rewriter
                             .create<tosa::ReshapeOp>(
                                 op.getLoc(), revisedInputShapeType, input,
                                 rewriter.getI64ArrayAttr(revisedInputShape))
                             .getResult();

    // Reshape kernel from [KH, KW, C, M] to [1, 1, 1, C, M].
    llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
                                                     weightShape[3]};
    auto revisedWeightShapeType = RankedTensorType::get(
        revisedWeightShape,
        weight.getType().dyn_cast<RankedTensorType>().getElementType());
    auto reshapedWeight = rewriter
                              .create<tosa::ReshapeOp>(
                                  op.getLoc(), revisedWeightShapeType, weight,
                                  rewriter.getI64ArrayAttr(revisedWeightShape))
                              .getResult();

    // Perform an elementwise mul over the reshaped input and weight.
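    // The reshaped input [N, H, W, C, 1] and weight [1, 1, 1, C, M] broadcast
    // against each other, so the elementwise product has shape
    // [N, H, W, C, M]; for a 1x1 kernel with unit stride this is exactly the
    // per-channel multiply that a depthwise convolution performs.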
    llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
                                           inputShape[2], inputShape[3],
                                           weightShape[3]};
    auto mulShapeType = RankedTensorType::get(
        mulShape,
        weight.getType().dyn_cast<RankedTensorType>().getElementType());
    Value mulValue =
        rewriter
            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
                                 reshapedWeight, /*shift=*/0)
            .getResult();

    // Reshape output to [N, H, W, C * M].
    auto outputShape = op.output().getType().cast<ShapedType>().getShape();
    auto outputShapeType = RankedTensorType::get(
        outputShape,
        input.getType().dyn_cast<RankedTensorType>().getElementType());
    auto outputValue =
        rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
                                         rewriter.getI64ArrayAttr(outputShape));

    // Add in the bias.
    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
                                             op.bias());
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
                                                RewritePatternSet &patterns) {
  patterns.insert<DepthwiseConv2DIsMul>(ctx);
}