1 //===- NamedOpConversions.cpp - Implements conversions between named ops --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements conversions between named ops that can be seen as
// canonicalizations of named ops.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Passes.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20 #include "llvm/ADT/SmallVector.h"
21
22 using namespace mlir;
23 using namespace mlir::linalg;
24
getIndicesVector(int start,int end)25 static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
26 return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
27 }
28
29 static LogicalResult
matchAndReplaceDepthwiseConv(Operation * operation,Value input,Value kernel,Value iZp,Value kZp,Value init,Attribute stride,Attribute dilation,PatternRewriter & rewriter)30 matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
31 Value iZp, Value kZp, Value init, Attribute stride,
32 Attribute dilation, PatternRewriter &rewriter) {
33 Location loc = operation->getLoc();
34 auto linalgOp = dyn_cast<LinalgOp>(operation);
35 // Exit out on the memref version of this operation.
36 if (!linalgOp || !linalgOp.hasTensorSemantics())
37 return failure();
38
39 auto result = operation->getResult(0);
40
41 auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
42 auto initTy = init.getType().dyn_cast<RankedTensorType>();
43 auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
44 if (!kernelTy || !initTy || !resultTy)
45 return failure();
46
47 if (kernelTy.getDimSize(3) != 1)
48 return failure();
49
50 // Collapse kernel dims.
51 SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
52 getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
53 auto newKernelTy = RankedTensorType::get(
54 {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
55 kernelTy.getElementType());
56 auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
57 loc, newKernelTy, kernel, collapsedKernelDims);
58
59 // Collapse init dims.
60 SmallVector<ReassociationIndices, 4> collapsedInitDims = {
61 getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
62 getIndicesVector(3, 5)};
63 auto newInitTy =
64 RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
65 initTy.getDimSize(2), initTy.getDimSize(3)},
66 initTy.getElementType());
67 auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
68 loc, newInitTy, init, collapsedInitDims);
69
70 Value newConv;
71 if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
72 newConv = rewriter
73 .create<DepthwiseConv2DNhwcHwcOp>(
74 loc, newInitTy, ValueRange{input, collapsedKernel},
75 ValueRange{collapsedInit}, stride, dilation)
76 .getResult(0);
77 } else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
78 newConv =
79 rewriter
80 .create<DepthwiseConv2DNhwcHwcQOp>(
81 loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
82 ValueRange{collapsedInit}, stride, dilation)
83 .getResult(0);
84 }
85
86 if (!newConv)
87 return failure();
88
89 // Expand dimensions back out to
90 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
91 operation, resultTy, newConv, collapsedInitDims);
92 return success();
93 }
94
95 namespace {
96 struct SimplifyDepthwiseConvOp
97 : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
98 using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
99
matchAndRewrite__anon52f9791f0111::SimplifyDepthwiseConvOp100 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
101 PatternRewriter &rewriter) const override {
102 Operation *operation = op.getOperation();
103 Value input = op.getInputOperand(0)->get();
104 Value kernel = op.getInputOperand(1)->get();
105 Value init = op.getOutputOperand(0)->get();
106
107 auto stride = op.strides();
108 auto dilation = op.dilations();
109
110 return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
111 nullptr, init, stride, dilation,
112 rewriter);
113 }
114 };
115
116 struct SimplifyDepthwiseConvQOp
117 : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
118 using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
119
matchAndRewrite__anon52f9791f0111::SimplifyDepthwiseConvQOp120 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
121 PatternRewriter &rewriter) const override {
122 Operation *operation = op.getOperation();
123 Value input = op.getInputOperand(0)->get();
124 Value kernel = op.getInputOperand(1)->get();
125 Value iZp = op.getInputOperand(2)->get();
126 Value kZp = op.getInputOperand(3)->get();
127 Value init = op.getOutputOperand(0)->get();
128
129 auto stride = op.strides();
130 auto dilation = op.dilations();
131
132 return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
133 init, stride, dilation, rewriter);
134 }
135 };
136
137 struct LinalgNamedOpConversionPass
138 : public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
139 LinalgNamedOpConversionPass() = default;
140 LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) = default;
141
runOnOperation__anon52f9791f0111::LinalgNamedOpConversionPass142 void runOnOperation() override {
143 Operation *op = getOperation();
144 RewritePatternSet patterns(op->getContext());
145 populateLinalgNamedOpConversionPatterns(patterns);
146 if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
147 return signalPassFailure();
148 }
149 };
150 } // namespace
151
populateLinalgNamedOpConversionPatterns(RewritePatternSet & patterns)152 void mlir::linalg::populateLinalgNamedOpConversionPatterns(
153 RewritePatternSet &patterns) {
154 patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
155 patterns.getContext());
156 }
157
createLinalgNamedOpConversionPass()158 std::unique_ptr<Pass> mlir::createLinalgNamedOpConversionPass() {
159 return std::make_unique<LinalgNamedOpConversionPass>();
160 }
161