1*1dce51b8STres Popp //===- TensorToLinalgPass.cpp - Tensor to Linalg Passes -------------------===// 2*1dce51b8STres Popp // 3*1dce51b8STres Popp // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*1dce51b8STres Popp // See https://llvm.org/LICENSE.txt for license information. 5*1dce51b8STres Popp // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*1dce51b8STres Popp // 7*1dce51b8STres Popp //===----------------------------------------------------------------------===// 8*1dce51b8STres Popp // 9*1dce51b8STres Popp // This file implements a pass to convert Tensor dialect to Linalg dialect. 10*1dce51b8STres Popp // 11*1dce51b8STres Popp //===----------------------------------------------------------------------===// 12*1dce51b8STres Popp 13*1dce51b8STres Popp #include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" 14*1dce51b8STres Popp #include "../PassDetail.h" 15*1dce51b8STres Popp #include "mlir/Conversion/TensorToLinalg/TensorToLinalg.h" 16*1dce51b8STres Popp #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17*1dce51b8STres Popp #include "mlir/Dialect/Linalg/IR/Linalg.h" 18*1dce51b8STres Popp #include "mlir/Dialect/Tensor/IR/Tensor.h" 19*1dce51b8STres Popp 20*1dce51b8STres Popp using namespace mlir; 21*1dce51b8STres Popp 22*1dce51b8STres Popp namespace { 23*1dce51b8STres Popp /// A pass converting MLIR Tensor operations into the Linalg dialect. 24*1dce51b8STres Popp class ConvertTensorToLinalgPass 25*1dce51b8STres Popp : public ConvertTensorToLinalgBase<ConvertTensorToLinalgPass> { runOnOperation()26*1dce51b8STres Popp void runOnOperation() override { 27*1dce51b8STres Popp auto &context = getContext(); 28*1dce51b8STres Popp ConversionTarget target(context); 29*1dce51b8STres Popp target.addLegalDialect<mlir::arith::ArithmeticDialect, 30*1dce51b8STres Popp mlir::linalg::LinalgDialect, 31*1dce51b8STres Popp mlir::tensor::TensorDialect>(); 32*1dce51b8STres Popp target.addIllegalOp<mlir::tensor::PadOp>(); 33*1dce51b8STres Popp 34*1dce51b8STres Popp RewritePatternSet patterns(&context); 35*1dce51b8STres Popp populateTensorToLinalgPatterns(patterns); 36*1dce51b8STres Popp 37*1dce51b8STres Popp if (failed(applyPartialConversion(getOperation(), target, 38*1dce51b8STres Popp std::move(patterns)))) 39*1dce51b8STres Popp return signalPassFailure(); 40*1dce51b8STres Popp } 41*1dce51b8STres Popp }; 42*1dce51b8STres Popp } // namespace 43*1dce51b8STres Popp 44*1dce51b8STres Popp std::unique_ptr<OperationPass<ModuleOp>> createConvertTensorToLinalgPass()45*1dce51b8STres Poppmlir::createConvertTensorToLinalgPass() { 46*1dce51b8STres Popp return std::make_unique<ConvertTensorToLinalgPass>(); 47*1dce51b8STres Popp } 48