1 //===- NamedOpConversions.cpp - Implements conversions between named ops --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements conversions between named ops that can be seens as 10 // canonicalizations of named ops. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/Linalg/Passes.h" 16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 20 #include "llvm/ADT/SmallVector.h" 21 22 using namespace mlir; 23 using namespace mlir::linalg; 24 25 static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) { 26 return llvm::to_vector<2>(llvm::seq<int64_t>(start, end)); 27 } 28 29 static LogicalResult 30 matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, 31 Value iZp, Value kZp, Value init, Attribute stride, 32 Attribute dilation, PatternRewriter &rewriter) { 33 Location loc = operation->getLoc(); 34 auto linalgOp = dyn_cast<LinalgOp>(operation); 35 // Exit out on the memref version of this operation. 36 if (!linalgOp || !linalgOp.hasTensorSemantics()) 37 return failure(); 38 39 auto result = operation->getResult(0); 40 41 auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>(); 42 auto initTy = init.getType().dyn_cast<RankedTensorType>(); 43 auto resultTy = result.getType().template dyn_cast<RankedTensorType>(); 44 if (!kernelTy || !initTy || !resultTy) 45 return failure(); 46 47 if (kernelTy.getDimSize(3) != 1) 48 return failure(); 49 50 // Collapse kernel dims. 51 SmallVector<ReassociationIndices, 4> collapsedKernelDims = { 52 getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)}; 53 auto newKernelTy = RankedTensorType::get( 54 {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)}, 55 kernelTy.getElementType()); 56 auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>( 57 loc, newKernelTy, kernel, collapsedKernelDims); 58 59 // Collapse init dims. 60 SmallVector<ReassociationIndices, 4> collapsedInitDims = { 61 getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3), 62 getIndicesVector(3, 5)}; 63 auto newInitTy = 64 RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1), 65 initTy.getDimSize(2), initTy.getDimSize(3)}, 66 initTy.getElementType()); 67 auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>( 68 loc, newInitTy, init, collapsedInitDims); 69 70 Value newConv; 71 if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) { 72 newConv = rewriter 73 .create<DepthwiseConv2DNhwcHwcOp>( 74 loc, newInitTy, ValueRange{input, collapsedKernel}, 75 ValueRange{collapsedInit}, stride, dilation) 76 .getResult(0); 77 } else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) { 78 newConv = 79 rewriter 80 .create<DepthwiseConv2DNhwcHwcQOp>( 81 loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp}, 82 ValueRange{collapsedInit}, stride, dilation) 83 .getResult(0); 84 } 85 86 if (!newConv) 87 return failure(); 88 89 // Expand dimensions back out to 90 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( 91 operation, resultTy, newConv, collapsedInitDims); 92 return success(); 93 } 94 95 namespace { 96 struct SimplifyDepthwiseConvOp 97 : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> { 98 using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern; 99 100 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op, 101 PatternRewriter &rewriter) const override { 102 Operation *operation = op.getOperation(); 103 Value input = op.getInputOperand(0)->get(); 104 Value kernel = op.getInputOperand(1)->get(); 105 Value init = op.getOutputOperand(0)->get(); 106 107 auto stride = op.strides(); 108 auto dilation = op.dilations(); 109 110 return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr, 111 nullptr, init, stride, dilation, 112 rewriter); 113 } 114 }; 115 116 struct SimplifyDepthwiseConvQOp 117 : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> { 118 using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern; 119 120 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op, 121 PatternRewriter &rewriter) const override { 122 Operation *operation = op.getOperation(); 123 Value input = op.getInputOperand(0)->get(); 124 Value kernel = op.getInputOperand(1)->get(); 125 Value iZp = op.getInputOperand(2)->get(); 126 Value kZp = op.getInputOperand(3)->get(); 127 Value init = op.getOutputOperand(0)->get(); 128 129 auto stride = op.strides(); 130 auto dilation = op.dilations(); 131 132 return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp, 133 init, stride, dilation, rewriter); 134 } 135 }; 136 137 struct LinalgNamedOpConversionPass 138 : public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> { 139 LinalgNamedOpConversionPass() = default; 140 LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) = default; 141 142 void runOnOperation() override { 143 Operation *op = getOperation(); 144 RewritePatternSet patterns(op->getContext()); 145 populateLinalgNamedOpConversionPatterns(patterns); 146 if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) 147 return signalPassFailure(); 148 } 149 }; 150 } // namespace 151 152 void mlir::linalg::populateLinalgNamedOpConversionPatterns( 153 RewritePatternSet &patterns) { 154 patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>( 155 patterns.getContext()); 156 } 157 158 std::unique_ptr<Pass> mlir::createLinalgNamedOpConversionPass() { 159 return std::make_unique<LinalgNamedOpConversionPass>(); 160 } 161