14142932aSMaheshRavishankar //===- NamedOpConversions.cpp - Implements conversions between named ops --===//
24142932aSMaheshRavishankar //
34142932aSMaheshRavishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44142932aSMaheshRavishankar // See https://llvm.org/LICENSE.txt for license information.
54142932aSMaheshRavishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64142932aSMaheshRavishankar //
74142932aSMaheshRavishankar //===----------------------------------------------------------------------===//
84142932aSMaheshRavishankar //
// This file implements conversions between named ops that can be seen as
// canonicalizations of named ops.
114142932aSMaheshRavishankar //
124142932aSMaheshRavishankar //===----------------------------------------------------------------------===//
134142932aSMaheshRavishankar #include "PassDetail.h"
144142932aSMaheshRavishankar #include "mlir/Dialect/Linalg/IR/Linalg.h"
154142932aSMaheshRavishankar #include "mlir/Dialect/Linalg/Passes.h"
164142932aSMaheshRavishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
174142932aSMaheshRavishankar #include "mlir/IR/PatternMatch.h"
184142932aSMaheshRavishankar #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
194142932aSMaheshRavishankar 
204142932aSMaheshRavishankar #include "llvm/ADT/SmallVector.h"
214142932aSMaheshRavishankar 
224142932aSMaheshRavishankar using namespace mlir;
234142932aSMaheshRavishankar using namespace mlir::linalg;
244142932aSMaheshRavishankar 
getIndicesVector(int start,int end)254142932aSMaheshRavishankar static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
264142932aSMaheshRavishankar   return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
274142932aSMaheshRavishankar }
284142932aSMaheshRavishankar 
294142932aSMaheshRavishankar static LogicalResult
matchAndReplaceDepthwiseConv(Operation * operation,Value input,Value kernel,Value iZp,Value kZp,Value init,Attribute stride,Attribute dilation,PatternRewriter & rewriter)304142932aSMaheshRavishankar matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
314142932aSMaheshRavishankar                              Value iZp, Value kZp, Value init, Attribute stride,
324142932aSMaheshRavishankar                              Attribute dilation, PatternRewriter &rewriter) {
334142932aSMaheshRavishankar   Location loc = operation->getLoc();
344142932aSMaheshRavishankar   auto linalgOp = dyn_cast<LinalgOp>(operation);
354142932aSMaheshRavishankar   // Exit out on the memref version of this operation.
364142932aSMaheshRavishankar   if (!linalgOp || !linalgOp.hasTensorSemantics())
374142932aSMaheshRavishankar     return failure();
384142932aSMaheshRavishankar 
394142932aSMaheshRavishankar   auto result = operation->getResult(0);
404142932aSMaheshRavishankar 
414142932aSMaheshRavishankar   auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
424142932aSMaheshRavishankar   auto initTy = init.getType().dyn_cast<RankedTensorType>();
434142932aSMaheshRavishankar   auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
444142932aSMaheshRavishankar   if (!kernelTy || !initTy || !resultTy)
454142932aSMaheshRavishankar     return failure();
464142932aSMaheshRavishankar 
474142932aSMaheshRavishankar   if (kernelTy.getDimSize(3) != 1)
484142932aSMaheshRavishankar     return failure();
494142932aSMaheshRavishankar 
504142932aSMaheshRavishankar   // Collapse kernel dims.
514142932aSMaheshRavishankar   SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
524142932aSMaheshRavishankar       getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
534142932aSMaheshRavishankar   auto newKernelTy = RankedTensorType::get(
544142932aSMaheshRavishankar       {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
554142932aSMaheshRavishankar       kernelTy.getElementType());
564142932aSMaheshRavishankar   auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
574142932aSMaheshRavishankar       loc, newKernelTy, kernel, collapsedKernelDims);
584142932aSMaheshRavishankar 
594142932aSMaheshRavishankar   // Collapse init dims.
604142932aSMaheshRavishankar   SmallVector<ReassociationIndices, 4> collapsedInitDims = {
614142932aSMaheshRavishankar       getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
624142932aSMaheshRavishankar       getIndicesVector(3, 5)};
634142932aSMaheshRavishankar   auto newInitTy =
644142932aSMaheshRavishankar       RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
654142932aSMaheshRavishankar                              initTy.getDimSize(2), initTy.getDimSize(3)},
664142932aSMaheshRavishankar                             initTy.getElementType());
674142932aSMaheshRavishankar   auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
684142932aSMaheshRavishankar       loc, newInitTy, init, collapsedInitDims);
694142932aSMaheshRavishankar 
704142932aSMaheshRavishankar   Value newConv;
714142932aSMaheshRavishankar   if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
724142932aSMaheshRavishankar     newConv = rewriter
734142932aSMaheshRavishankar                   .create<DepthwiseConv2DNhwcHwcOp>(
744142932aSMaheshRavishankar                       loc, newInitTy, ValueRange{input, collapsedKernel},
754142932aSMaheshRavishankar                       ValueRange{collapsedInit}, stride, dilation)
764142932aSMaheshRavishankar                   .getResult(0);
774142932aSMaheshRavishankar   } else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
784142932aSMaheshRavishankar     newConv =
794142932aSMaheshRavishankar         rewriter
804142932aSMaheshRavishankar             .create<DepthwiseConv2DNhwcHwcQOp>(
814142932aSMaheshRavishankar                 loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
824142932aSMaheshRavishankar                 ValueRange{collapsedInit}, stride, dilation)
834142932aSMaheshRavishankar             .getResult(0);
844142932aSMaheshRavishankar   }
854142932aSMaheshRavishankar 
864142932aSMaheshRavishankar   if (!newConv)
874142932aSMaheshRavishankar     return failure();
884142932aSMaheshRavishankar 
894142932aSMaheshRavishankar   // Expand dimensions back out to
904142932aSMaheshRavishankar   rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
914142932aSMaheshRavishankar       operation, resultTy, newConv, collapsedInitDims);
924142932aSMaheshRavishankar   return success();
934142932aSMaheshRavishankar }
944142932aSMaheshRavishankar 
954142932aSMaheshRavishankar namespace {
964142932aSMaheshRavishankar struct SimplifyDepthwiseConvOp
974142932aSMaheshRavishankar     : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
984142932aSMaheshRavishankar   using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
994142932aSMaheshRavishankar 
matchAndRewrite__anon52f9791f0111::SimplifyDepthwiseConvOp1004142932aSMaheshRavishankar   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
1014142932aSMaheshRavishankar                                 PatternRewriter &rewriter) const override {
1024142932aSMaheshRavishankar     Operation *operation = op.getOperation();
1034142932aSMaheshRavishankar     Value input = op.getInputOperand(0)->get();
1044142932aSMaheshRavishankar     Value kernel = op.getInputOperand(1)->get();
1054142932aSMaheshRavishankar     Value init = op.getOutputOperand(0)->get();
1064142932aSMaheshRavishankar 
1074142932aSMaheshRavishankar     auto stride = op.strides();
1084142932aSMaheshRavishankar     auto dilation = op.dilations();
1094142932aSMaheshRavishankar 
1104142932aSMaheshRavishankar     return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
1114142932aSMaheshRavishankar                                         nullptr, init, stride, dilation,
1124142932aSMaheshRavishankar                                         rewriter);
1134142932aSMaheshRavishankar   }
1144142932aSMaheshRavishankar };
1154142932aSMaheshRavishankar 
1164142932aSMaheshRavishankar struct SimplifyDepthwiseConvQOp
1174142932aSMaheshRavishankar     : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
1184142932aSMaheshRavishankar   using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
1194142932aSMaheshRavishankar 
matchAndRewrite__anon52f9791f0111::SimplifyDepthwiseConvQOp1204142932aSMaheshRavishankar   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
1214142932aSMaheshRavishankar                                 PatternRewriter &rewriter) const override {
1224142932aSMaheshRavishankar     Operation *operation = op.getOperation();
1234142932aSMaheshRavishankar     Value input = op.getInputOperand(0)->get();
1244142932aSMaheshRavishankar     Value kernel = op.getInputOperand(1)->get();
1254142932aSMaheshRavishankar     Value iZp = op.getInputOperand(2)->get();
1264142932aSMaheshRavishankar     Value kZp = op.getInputOperand(3)->get();
1274142932aSMaheshRavishankar     Value init = op.getOutputOperand(0)->get();
1284142932aSMaheshRavishankar 
1294142932aSMaheshRavishankar     auto stride = op.strides();
1304142932aSMaheshRavishankar     auto dilation = op.dilations();
1314142932aSMaheshRavishankar 
1324142932aSMaheshRavishankar     return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
1334142932aSMaheshRavishankar                                         init, stride, dilation, rewriter);
1344142932aSMaheshRavishankar   }
1354142932aSMaheshRavishankar };
1364142932aSMaheshRavishankar 
1374142932aSMaheshRavishankar struct LinalgNamedOpConversionPass
1384142932aSMaheshRavishankar     : public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
1394142932aSMaheshRavishankar   LinalgNamedOpConversionPass() = default;
140*9a7d111fSNicolas Vasilache   LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) = default;
1414142932aSMaheshRavishankar 
runOnOperation__anon52f9791f0111::LinalgNamedOpConversionPass1424142932aSMaheshRavishankar   void runOnOperation() override {
1434142932aSMaheshRavishankar     Operation *op = getOperation();
1444142932aSMaheshRavishankar     RewritePatternSet patterns(op->getContext());
1454142932aSMaheshRavishankar     populateLinalgNamedOpConversionPatterns(patterns);
1464142932aSMaheshRavishankar     if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
1474142932aSMaheshRavishankar       return signalPassFailure();
1484142932aSMaheshRavishankar   }
1494142932aSMaheshRavishankar };
1504142932aSMaheshRavishankar } // namespace
1514142932aSMaheshRavishankar 
populateLinalgNamedOpConversionPatterns(RewritePatternSet & patterns)1524142932aSMaheshRavishankar void mlir::linalg::populateLinalgNamedOpConversionPatterns(
1534142932aSMaheshRavishankar     RewritePatternSet &patterns) {
1544142932aSMaheshRavishankar   patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
1554142932aSMaheshRavishankar       patterns.getContext());
1564142932aSMaheshRavishankar }
1574142932aSMaheshRavishankar 
createLinalgNamedOpConversionPass()1584142932aSMaheshRavishankar std::unique_ptr<Pass> mlir::createLinalgNamedOpConversionPass() {
1594142932aSMaheshRavishankar   return std::make_unique<LinalgNamedOpConversionPass>();
1604142932aSMaheshRavishankar }
161