1 //===- NamedOpConversions.cpp - Implements conversions between named ops --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements conversions between named ops that can be seen as
10 // canonicalizations of named ops.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Passes.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 #include "llvm/ADT/SmallVector.h"
21 
22 using namespace mlir;
23 using namespace mlir::linalg;
24 
getIndicesVector(int start,int end)25 static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
26   return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
27 }
28 
29 static LogicalResult
matchAndReplaceDepthwiseConv(Operation * operation,Value input,Value kernel,Value iZp,Value kZp,Value init,Attribute stride,Attribute dilation,PatternRewriter & rewriter)30 matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
31                              Value iZp, Value kZp, Value init, Attribute stride,
32                              Attribute dilation, PatternRewriter &rewriter) {
33   Location loc = operation->getLoc();
34   auto linalgOp = dyn_cast<LinalgOp>(operation);
35   // Exit out on the memref version of this operation.
36   if (!linalgOp || !linalgOp.hasTensorSemantics())
37     return failure();
38 
39   auto result = operation->getResult(0);
40 
41   auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
42   auto initTy = init.getType().dyn_cast<RankedTensorType>();
43   auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
44   if (!kernelTy || !initTy || !resultTy)
45     return failure();
46 
47   if (kernelTy.getDimSize(3) != 1)
48     return failure();
49 
50   // Collapse kernel dims.
51   SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
52       getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
53   auto newKernelTy = RankedTensorType::get(
54       {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
55       kernelTy.getElementType());
56   auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
57       loc, newKernelTy, kernel, collapsedKernelDims);
58 
59   // Collapse init dims.
60   SmallVector<ReassociationIndices, 4> collapsedInitDims = {
61       getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
62       getIndicesVector(3, 5)};
63   auto newInitTy =
64       RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
65                              initTy.getDimSize(2), initTy.getDimSize(3)},
66                             initTy.getElementType());
67   auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
68       loc, newInitTy, init, collapsedInitDims);
69 
70   Value newConv;
71   if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
72     newConv = rewriter
73                   .create<DepthwiseConv2DNhwcHwcOp>(
74                       loc, newInitTy, ValueRange{input, collapsedKernel},
75                       ValueRange{collapsedInit}, stride, dilation)
76                   .getResult(0);
77   } else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
78     newConv =
79         rewriter
80             .create<DepthwiseConv2DNhwcHwcQOp>(
81                 loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
82                 ValueRange{collapsedInit}, stride, dilation)
83             .getResult(0);
84   }
85 
86   if (!newConv)
87     return failure();
88 
89   // Expand dimensions back out to
90   rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
91       operation, resultTy, newConv, collapsedInitDims);
92   return success();
93 }
94 
95 namespace {
96 struct SimplifyDepthwiseConvOp
97     : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
98   using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
99 
matchAndRewrite__anon52f9791f0111::SimplifyDepthwiseConvOp100   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
101                                 PatternRewriter &rewriter) const override {
102     Operation *operation = op.getOperation();
103     Value input = op.getInputOperand(0)->get();
104     Value kernel = op.getInputOperand(1)->get();
105     Value init = op.getOutputOperand(0)->get();
106 
107     auto stride = op.strides();
108     auto dilation = op.dilations();
109 
110     return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
111                                         nullptr, init, stride, dilation,
112                                         rewriter);
113   }
114 };
115 
116 struct SimplifyDepthwiseConvQOp
117     : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
118   using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
119 
matchAndRewrite__anon52f9791f0111::SimplifyDepthwiseConvQOp120   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
121                                 PatternRewriter &rewriter) const override {
122     Operation *operation = op.getOperation();
123     Value input = op.getInputOperand(0)->get();
124     Value kernel = op.getInputOperand(1)->get();
125     Value iZp = op.getInputOperand(2)->get();
126     Value kZp = op.getInputOperand(3)->get();
127     Value init = op.getOutputOperand(0)->get();
128 
129     auto stride = op.strides();
130     auto dilation = op.dilations();
131 
132     return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
133                                         init, stride, dilation, rewriter);
134   }
135 };
136 
137 struct LinalgNamedOpConversionPass
138     : public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
139   LinalgNamedOpConversionPass() = default;
140   LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) = default;
141 
runOnOperation__anon52f9791f0111::LinalgNamedOpConversionPass142   void runOnOperation() override {
143     Operation *op = getOperation();
144     RewritePatternSet patterns(op->getContext());
145     populateLinalgNamedOpConversionPatterns(patterns);
146     if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
147       return signalPassFailure();
148   }
149 };
150 } // namespace
151 
populateLinalgNamedOpConversionPatterns(RewritePatternSet & patterns)152 void mlir::linalg::populateLinalgNamedOpConversionPatterns(
153     RewritePatternSet &patterns) {
154   patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
155       patterns.getContext());
156 }
157 
createLinalgNamedOpConversionPass()158 std::unique_ptr<Pass> mlir::createLinalgNamedOpConversionPass() {
159   return std::make_unique<LinalgNamedOpConversionPass>();
160 }
161