1 //===- InitTensorToAllocTensor.cpp - Lower init_tensor to alloc_tensor ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Linalg/Passes.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15 
16 using namespace mlir;
17 using namespace mlir::bufferization;
18 using namespace mlir::linalg;
19 
20 namespace {
21 struct InitTensorLoweringPattern : public OpRewritePattern<InitTensorOp> {
22   using OpRewritePattern<InitTensorOp>::OpRewritePattern;
23 
matchAndRewrite__anonb0be89fa0111::InitTensorLoweringPattern24   LogicalResult matchAndRewrite(InitTensorOp op,
25                                 PatternRewriter &rewriter) const override {
26     rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(op, op.getType(),
27                                                               op.sizes());
28     return success();
29   }
30 };
31 
32 struct LinalgInitTensorToAllocTensor
33     : public LinalgInitTensorToAllocTensorBase<LinalgInitTensorToAllocTensor> {
34   LinalgInitTensorToAllocTensor() = default;
35 
36   void runOnOperation() override;
37 
getDependentDialects__anonb0be89fa0111::LinalgInitTensorToAllocTensor38   void getDependentDialects(DialectRegistry &registry) const override {
39     registry
40         .insert<linalg::LinalgDialect, bufferization::BufferizationDialect>();
41   }
42 };
43 } // namespace
44 
runOnOperation()45 void LinalgInitTensorToAllocTensor::runOnOperation() {
46   Operation *op = getOperation();
47   RewritePatternSet patterns(op->getContext());
48   patterns.insert<InitTensorLoweringPattern>(op->getContext());
49   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
50     signalPassFailure();
51 }
52 
createLinalgInitTensorToAllocTensorPass()53 std::unique_ptr<Pass> mlir::createLinalgInitTensorToAllocTensorPass() {
54   return std::make_unique<LinalgInitTensorToAllocTensor>();
55 }
56