//===- InitTensorToAllocTensor.cpp - Lower init_tensor to alloc_tensor ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::bufferization; using namespace mlir::linalg; namespace { struct InitTensorLoweringPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InitTensorOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, op.getType(), op.sizes()); return success(); } }; struct LinalgInitTensorToAllocTensor : public LinalgInitTensorToAllocTensorBase { LinalgInitTensorToAllocTensor() = default; void runOnOperation() override; void getDependentDialects(DialectRegistry ®istry) const override { registry .insert(); } }; } // namespace void LinalgInitTensorToAllocTensor::runOnOperation() { Operation *op = getOperation(); RewritePatternSet patterns(op->getContext()); patterns.insert(op->getContext()); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) signalPassFailure(); } std::unique_ptr mlir::createLinalgInitTensorToAllocTensorPass() { return std::make_unique(); }