//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA DepthwiseConv2D operation into a series of TOSA ops,
// specifically:
// (1) Convert a 1x1 DepthwiseConv2D to Reshape -> Mul -> Reshape -> Add.
//
//===----------------------------------------------------------------------===//
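// For illustration, with hypothetical shapes (the pattern below additionally
// requires static shapes, unit strides, a float element type, and no
// quantization info), a 1x1 depthwise convolution such as:
//
//   %res = tosa.depthwise_conv2d %input, %weight, %bias
//       : (tensor<1x8x8x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>)
//       -> tensor<1x8x8x8xf32>
//
// is rewritten, conceptually, into:
//
//   %in  = tosa.reshape %input   : ... -> tensor<1x8x8x4x1xf32>
//   %wt  = tosa.reshape %weight  : ... -> tensor<1x1x1x4x2xf32>
//   %mul = tosa.mul %in, %wt     : ... -> tensor<1x8x8x4x2xf32>
//   %out = tosa.reshape %mul     : ... -> tensor<1x8x8x8xf32>
//   %res = tosa.add %out, %bias  : ... -> tensor<1x8x8x8xf32>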

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

struct DepthwiseConv2DIsMul
    : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
  explicit DepthwiseConv2DIsMul(MLIRContext *context)
      : OpRewritePattern(context) {}

  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value weight = op.getWeight();
    ShapedType inputType = input.getType().cast<ShapedType>();
    ShapedType weightType = weight.getType().cast<ShapedType>();
    ShapedType resultType = op.getOutput().getType().cast<ShapedType>();
    Type inputEType = inputType.getElementType();

    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
          resultType.hasStaticShape())) {
      return failure();
    }

    // Quantization is not handled here; only decompose non-quantized
    // floating-point ops.
    if (op.getQuantizationInfo() || !inputEType.isa<FloatType>()) {
      return failure();
    }

    // Stride must be 1 for this optimization.
    for (Attribute stride : op.getStride().getValue()) {
      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
        return failure();
      }
    }

    // Only works for a 1x1 kernel.
    ArrayRef<int64_t> weightShape = weightType.getShape();
    if (weightShape[0] != 1 || weightShape[1] != 1) {
      return failure();
    }

    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
    ArrayRef<int64_t> inputShape = inputType.getShape();
    llvm::SmallVector<int64_t> revisedInputShape{
        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
    auto revisedInputShapeType = RankedTensorType::get(
        revisedInputShape,
        input.getType().dyn_cast<RankedTensorType>().getElementType());
    auto reshapedInput = rewriter
                             .create<tosa::ReshapeOp>(
                                 op.getLoc(), revisedInputShapeType, input,
                                 rewriter.getI64ArrayAttr(revisedInputShape))
                             .getResult();

    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
    llvm::SmallVector<int64_t> revisedWeightShape{1, 1, 1, weightShape[2],
                                                  weightShape[3]};
    auto revisedWeightShapeType = RankedTensorType::get(
        revisedWeightShape,
        weight.getType().dyn_cast<RankedTensorType>().getElementType());
    auto reshapedWeight = rewriter
                              .create<tosa::ReshapeOp>(
                                  op.getLoc(), revisedWeightShapeType, weight,
                                  rewriter.getI64ArrayAttr(revisedWeightShape))
                              .getResult();

    // Perform an elementwise mul over the reshaped input and weight.
    llvm::SmallVector<int64_t> mulShape{inputShape[0], inputShape[1],
                                        inputShape[2], inputShape[3],
                                        weightShape[3]};
    auto mulShapeType = RankedTensorType::get(
        mulShape,
        weight.getType().dyn_cast<RankedTensorType>().getElementType());
    Value mulValue = rewriter
                         .create<tosa::MulOp>(op.getLoc(), mulShapeType,
                                              reshapedInput, reshapedWeight,
                                              /*shift=*/0)
                         .getResult();

    // Reshape output to [N, H, W, C * M].
    auto outputShape = op.getOutput().getType().cast<ShapedType>().getShape();
    auto outputShapeType = RankedTensorType::get(
        outputShape,
        input.getType().dyn_cast<RankedTensorType>().getElementType());
    auto outputValue = rewriter.create<tosa::ReshapeOp>(
        op.getLoc(), outputShapeType, mulValue,
        rewriter.getI64ArrayAttr(outputShape));

    // Add in the bias.
    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
                                             op.getBias());
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
                                                RewritePatternSet &patterns) {
  patterns.add<DepthwiseConv2DIsMul>(ctx);
}