1 //===- TosaDecomposeDepthwise.cpp -----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Decompose TOSA Depthwise operation to a series of TOSA Ops specifically 10 // (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 15 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 16 #include "mlir/Pass/Pass.h" 17 18 using namespace mlir; 19 using namespace mlir::tosa; 20 21 namespace { 22 23 struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> { 24 explicit DepthwiseConv2DIsMul(MLIRContext *context) 25 : OpRewritePattern(context) {} 26 27 LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op, 28 PatternRewriter &rewriter) const override { 29 Value input = op.getInput(); 30 Value weight = op.getWeight(); 31 ShapedType inputType = input.getType().cast<ShapedType>(); 32 ShapedType weightType = weight.getType().cast<ShapedType>(); 33 ShapedType resultType = op.getOutput().getType().cast<ShapedType>(); 34 Type inputEType = inputType.getElementType(); 35 36 if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && 37 resultType.hasStaticShape())) { 38 return failure(); 39 } 40 41 // Quantization information needs to still be performed. 42 if (op.getQuantizationInfo() || !inputEType.isa<FloatType>()) { 43 return failure(); 44 } 45 46 // Stride must be 1 for this optimization. 47 for (Attribute stride : op.getStride().getValue()) { 48 if (!stride.cast<IntegerAttr>().getValue().isOne()) { 49 return failure(); 50 } 51 } 52 53 // Only works for a 1x1 kernel. 54 ArrayRef<int64_t> weightShape = weightType.getShape(); 55 if (weightShape[0] != 1 || weightShape[1] != 1) { 56 return failure(); 57 } 58 59 // Reshape input to [N, H, W, C] -> [N, H, W, C, 1]. 60 ArrayRef<int64_t> inputShape = inputType.getShape(); 61 llvm::SmallVector<int64_t, 2> revisedInputShape{ 62 inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1}; 63 auto revisedInputShapeType = RankedTensorType::get( 64 revisedInputShape, 65 input.getType().dyn_cast<RankedTensorType>().getElementType()); 66 auto reshapedInput = rewriter 67 .create<tosa::ReshapeOp>( 68 op.getLoc(), revisedInputShapeType, input, 69 rewriter.getI64ArrayAttr(revisedInputShape)) 70 .getResult(); 71 72 // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M]. 73 llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2], 74 weightShape[3]}; 75 auto revisedWeightShapeType = RankedTensorType::get( 76 revisedWeightShape, 77 weight.getType().dyn_cast<RankedTensorType>().getElementType()); 78 auto reshapedWeight = rewriter 79 .create<tosa::ReshapeOp>( 80 op.getLoc(), revisedWeightShapeType, weight, 81 rewriter.getI64ArrayAttr(revisedWeightShape)) 82 .getResult(); 83 84 // Perform an elementwise mul over the reshaped input and weight. 85 llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1], 86 inputShape[2], inputShape[3], 87 weightShape[3]}; 88 auto mulShapeType = RankedTensorType::get( 89 mulShape, 90 weight.getType().dyn_cast<RankedTensorType>().getElementType()); 91 Value mulValue = 92 rewriter 93 .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput, 94 reshapedWeight, /*shift=*/0) 95 .getResult(); 96 97 // Reshape output to [N, H, W, C * M]. 98 auto outputShape = op.getOutput().getType().cast<ShapedType>().getShape(); 99 auto outputShapeType = RankedTensorType::get( 100 outputShape, 101 input.getType().dyn_cast<RankedTensorType>().getElementType()); 102 auto outputValue = 103 rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue, 104 rewriter.getI64ArrayAttr(outputShape)); 105 106 // Add in the bias. 107 rewriter 108 .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, 109 op.getBias()) 110 .getResult(); 111 return success(); 112 } 113 }; 114 115 } // namespace 116 117 void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx, 118 RewritePatternSet &patterns) { 119 patterns.add<DepthwiseConv2DIsMul>(ctx); 120 } 121