14142932aSMaheshRavishankar //===- NamedOpConversions.cpp - Implements conversions between named ops --===//
24142932aSMaheshRavishankar //
34142932aSMaheshRavishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44142932aSMaheshRavishankar // See https://llvm.org/LICENSE.txt for license information.
54142932aSMaheshRavishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64142932aSMaheshRavishankar //
74142932aSMaheshRavishankar //===----------------------------------------------------------------------===//
84142932aSMaheshRavishankar //
// This file implements conversions between named ops that can be seen as
104142932aSMaheshRavishankar // canonicalizations of named ops.
114142932aSMaheshRavishankar //
124142932aSMaheshRavishankar //===----------------------------------------------------------------------===//
134142932aSMaheshRavishankar #include "PassDetail.h"
144142932aSMaheshRavishankar #include "mlir/Dialect/Linalg/IR/Linalg.h"
154142932aSMaheshRavishankar #include "mlir/Dialect/Linalg/Passes.h"
164142932aSMaheshRavishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
174142932aSMaheshRavishankar #include "mlir/IR/PatternMatch.h"
184142932aSMaheshRavishankar #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
194142932aSMaheshRavishankar
204142932aSMaheshRavishankar #include "llvm/ADT/SmallVector.h"
214142932aSMaheshRavishankar
224142932aSMaheshRavishankar using namespace mlir;
234142932aSMaheshRavishankar using namespace mlir::linalg;
244142932aSMaheshRavishankar
getIndicesVector(int start,int end)254142932aSMaheshRavishankar static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
264142932aSMaheshRavishankar return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
274142932aSMaheshRavishankar }
284142932aSMaheshRavishankar
294142932aSMaheshRavishankar static LogicalResult
matchAndReplaceDepthwiseConv(Operation * operation,Value input,Value kernel,Value iZp,Value kZp,Value init,Attribute stride,Attribute dilation,PatternRewriter & rewriter)304142932aSMaheshRavishankar matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
314142932aSMaheshRavishankar Value iZp, Value kZp, Value init, Attribute stride,
324142932aSMaheshRavishankar Attribute dilation, PatternRewriter &rewriter) {
334142932aSMaheshRavishankar Location loc = operation->getLoc();
344142932aSMaheshRavishankar auto linalgOp = dyn_cast<LinalgOp>(operation);
354142932aSMaheshRavishankar // Exit out on the memref version of this operation.
364142932aSMaheshRavishankar if (!linalgOp || !linalgOp.hasTensorSemantics())
374142932aSMaheshRavishankar return failure();
384142932aSMaheshRavishankar
394142932aSMaheshRavishankar auto result = operation->getResult(0);
404142932aSMaheshRavishankar
414142932aSMaheshRavishankar auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
424142932aSMaheshRavishankar auto initTy = init.getType().dyn_cast<RankedTensorType>();
434142932aSMaheshRavishankar auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
444142932aSMaheshRavishankar if (!kernelTy || !initTy || !resultTy)
454142932aSMaheshRavishankar return failure();
464142932aSMaheshRavishankar
474142932aSMaheshRavishankar if (kernelTy.getDimSize(3) != 1)
484142932aSMaheshRavishankar return failure();
494142932aSMaheshRavishankar
504142932aSMaheshRavishankar // Collapse kernel dims.
514142932aSMaheshRavishankar SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
524142932aSMaheshRavishankar getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
534142932aSMaheshRavishankar auto newKernelTy = RankedTensorType::get(
544142932aSMaheshRavishankar {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
554142932aSMaheshRavishankar kernelTy.getElementType());
564142932aSMaheshRavishankar auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
574142932aSMaheshRavishankar loc, newKernelTy, kernel, collapsedKernelDims);
584142932aSMaheshRavishankar
594142932aSMaheshRavishankar // Collapse init dims.
604142932aSMaheshRavishankar SmallVector<ReassociationIndices, 4> collapsedInitDims = {
614142932aSMaheshRavishankar getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
624142932aSMaheshRavishankar getIndicesVector(3, 5)};
634142932aSMaheshRavishankar auto newInitTy =
644142932aSMaheshRavishankar RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
654142932aSMaheshRavishankar initTy.getDimSize(2), initTy.getDimSize(3)},
664142932aSMaheshRavishankar initTy.getElementType());
674142932aSMaheshRavishankar auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
684142932aSMaheshRavishankar loc, newInitTy, init, collapsedInitDims);
694142932aSMaheshRavishankar
704142932aSMaheshRavishankar Value newConv;
714142932aSMaheshRavishankar if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
724142932aSMaheshRavishankar newConv = rewriter
734142932aSMaheshRavishankar .create<DepthwiseConv2DNhwcHwcOp>(
744142932aSMaheshRavishankar loc, newInitTy, ValueRange{input, collapsedKernel},
754142932aSMaheshRavishankar ValueRange{collapsedInit}, stride, dilation)
764142932aSMaheshRavishankar .getResult(0);
774142932aSMaheshRavishankar } else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
784142932aSMaheshRavishankar newConv =
794142932aSMaheshRavishankar rewriter
804142932aSMaheshRavishankar .create<DepthwiseConv2DNhwcHwcQOp>(
814142932aSMaheshRavishankar loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
824142932aSMaheshRavishankar ValueRange{collapsedInit}, stride, dilation)
834142932aSMaheshRavishankar .getResult(0);
844142932aSMaheshRavishankar }
854142932aSMaheshRavishankar
864142932aSMaheshRavishankar if (!newConv)
874142932aSMaheshRavishankar return failure();
884142932aSMaheshRavishankar
894142932aSMaheshRavishankar // Expand dimensions back out to
904142932aSMaheshRavishankar rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
914142932aSMaheshRavishankar operation, resultTy, newConv, collapsedInitDims);
924142932aSMaheshRavishankar return success();
934142932aSMaheshRavishankar }
944142932aSMaheshRavishankar
954142932aSMaheshRavishankar namespace {
964142932aSMaheshRavishankar struct SimplifyDepthwiseConvOp
974142932aSMaheshRavishankar : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
984142932aSMaheshRavishankar using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
994142932aSMaheshRavishankar
matchAndRewrite__anon52f9791f0111::SimplifyDepthwiseConvOp1004142932aSMaheshRavishankar LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
1014142932aSMaheshRavishankar PatternRewriter &rewriter) const override {
1024142932aSMaheshRavishankar Operation *operation = op.getOperation();
1034142932aSMaheshRavishankar Value input = op.getInputOperand(0)->get();
1044142932aSMaheshRavishankar Value kernel = op.getInputOperand(1)->get();
1054142932aSMaheshRavishankar Value init = op.getOutputOperand(0)->get();
1064142932aSMaheshRavishankar
1074142932aSMaheshRavishankar auto stride = op.strides();
1084142932aSMaheshRavishankar auto dilation = op.dilations();
1094142932aSMaheshRavishankar
1104142932aSMaheshRavishankar return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
1114142932aSMaheshRavishankar nullptr, init, stride, dilation,
1124142932aSMaheshRavishankar rewriter);
1134142932aSMaheshRavishankar }
1144142932aSMaheshRavishankar };
1154142932aSMaheshRavishankar
1164142932aSMaheshRavishankar struct SimplifyDepthwiseConvQOp
1174142932aSMaheshRavishankar : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
1184142932aSMaheshRavishankar using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
1194142932aSMaheshRavishankar
matchAndRewrite__anon52f9791f0111::SimplifyDepthwiseConvQOp1204142932aSMaheshRavishankar LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
1214142932aSMaheshRavishankar PatternRewriter &rewriter) const override {
1224142932aSMaheshRavishankar Operation *operation = op.getOperation();
1234142932aSMaheshRavishankar Value input = op.getInputOperand(0)->get();
1244142932aSMaheshRavishankar Value kernel = op.getInputOperand(1)->get();
1254142932aSMaheshRavishankar Value iZp = op.getInputOperand(2)->get();
1264142932aSMaheshRavishankar Value kZp = op.getInputOperand(3)->get();
1274142932aSMaheshRavishankar Value init = op.getOutputOperand(0)->get();
1284142932aSMaheshRavishankar
1294142932aSMaheshRavishankar auto stride = op.strides();
1304142932aSMaheshRavishankar auto dilation = op.dilations();
1314142932aSMaheshRavishankar
1324142932aSMaheshRavishankar return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
1334142932aSMaheshRavishankar init, stride, dilation, rewriter);
1344142932aSMaheshRavishankar }
1354142932aSMaheshRavishankar };
1364142932aSMaheshRavishankar
1374142932aSMaheshRavishankar struct LinalgNamedOpConversionPass
1384142932aSMaheshRavishankar : public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
1394142932aSMaheshRavishankar LinalgNamedOpConversionPass() = default;
140*9a7d111fSNicolas Vasilache LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) = default;
1414142932aSMaheshRavishankar
runOnOperation__anon52f9791f0111::LinalgNamedOpConversionPass1424142932aSMaheshRavishankar void runOnOperation() override {
1434142932aSMaheshRavishankar Operation *op = getOperation();
1444142932aSMaheshRavishankar RewritePatternSet patterns(op->getContext());
1454142932aSMaheshRavishankar populateLinalgNamedOpConversionPatterns(patterns);
1464142932aSMaheshRavishankar if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
1474142932aSMaheshRavishankar return signalPassFailure();
1484142932aSMaheshRavishankar }
1494142932aSMaheshRavishankar };
1504142932aSMaheshRavishankar } // namespace
1514142932aSMaheshRavishankar
populateLinalgNamedOpConversionPatterns(RewritePatternSet & patterns)1524142932aSMaheshRavishankar void mlir::linalg::populateLinalgNamedOpConversionPatterns(
1534142932aSMaheshRavishankar RewritePatternSet &patterns) {
1544142932aSMaheshRavishankar patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
1554142932aSMaheshRavishankar patterns.getContext());
1564142932aSMaheshRavishankar }
1574142932aSMaheshRavishankar
createLinalgNamedOpConversionPass()1584142932aSMaheshRavishankar std::unique_ptr<Pass> mlir::createLinalgNamedOpConversionPass() {
1594142932aSMaheshRavishankar return std::make_unique<LinalgNamedOpConversionPass>();
1604142932aSMaheshRavishankar }
161