1*1dce51b8STres Popp //===- TensorToLinalgPass.cpp - Tensor to Linalg Passes -------------------===//
2*1dce51b8STres Popp //
3*1dce51b8STres Popp // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*1dce51b8STres Popp // See https://llvm.org/LICENSE.txt for license information.
5*1dce51b8STres Popp // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*1dce51b8STres Popp //
7*1dce51b8STres Popp //===----------------------------------------------------------------------===//
8*1dce51b8STres Popp //
9*1dce51b8STres Popp // This file implements a pass to convert Tensor dialect to Linalg dialect.
10*1dce51b8STres Popp //
11*1dce51b8STres Popp //===----------------------------------------------------------------------===//
12*1dce51b8STres Popp 
13*1dce51b8STres Popp #include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"
14*1dce51b8STres Popp #include "../PassDetail.h"
15*1dce51b8STres Popp #include "mlir/Conversion/TensorToLinalg/TensorToLinalg.h"
16*1dce51b8STres Popp #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17*1dce51b8STres Popp #include "mlir/Dialect/Linalg/IR/Linalg.h"
18*1dce51b8STres Popp #include "mlir/Dialect/Tensor/IR/Tensor.h"
19*1dce51b8STres Popp 
20*1dce51b8STres Popp using namespace mlir;
21*1dce51b8STres Popp 
22*1dce51b8STres Popp namespace {
23*1dce51b8STres Popp /// A pass converting MLIR Tensor operations into the Linalg dialect.
24*1dce51b8STres Popp class ConvertTensorToLinalgPass
25*1dce51b8STres Popp     : public ConvertTensorToLinalgBase<ConvertTensorToLinalgPass> {
runOnOperation()26*1dce51b8STres Popp   void runOnOperation() override {
27*1dce51b8STres Popp     auto &context = getContext();
28*1dce51b8STres Popp     ConversionTarget target(context);
29*1dce51b8STres Popp     target.addLegalDialect<mlir::arith::ArithmeticDialect,
30*1dce51b8STres Popp                            mlir::linalg::LinalgDialect,
31*1dce51b8STres Popp                            mlir::tensor::TensorDialect>();
32*1dce51b8STres Popp     target.addIllegalOp<mlir::tensor::PadOp>();
33*1dce51b8STres Popp 
34*1dce51b8STres Popp     RewritePatternSet patterns(&context);
35*1dce51b8STres Popp     populateTensorToLinalgPatterns(patterns);
36*1dce51b8STres Popp 
37*1dce51b8STres Popp     if (failed(applyPartialConversion(getOperation(), target,
38*1dce51b8STres Popp                                       std::move(patterns))))
39*1dce51b8STres Popp       return signalPassFailure();
40*1dce51b8STres Popp   }
41*1dce51b8STres Popp };
42*1dce51b8STres Popp } // namespace
43*1dce51b8STres Popp 
44*1dce51b8STres Popp std::unique_ptr<OperationPass<ModuleOp>>
createConvertTensorToLinalgPass()45*1dce51b8STres Popp mlir::createConvertTensorToLinalgPass() {
46*1dce51b8STres Popp   return std::make_unique<ConvertTensorToLinalgPass>();
47*1dce51b8STres Popp }
48