1 //===- TensorToLinalgPass.cpp - Tensor to Linalg Passes -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert Tensor dialect to Linalg dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" 14 #include "../PassDetail.h" 15 #include "mlir/Conversion/TensorToLinalg/TensorToLinalg.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/Tensor/IR/Tensor.h" 19 20 using namespace mlir; 21 22 namespace { 23 /// A pass converting MLIR Tensor operations into the Linalg dialect. 24 class ConvertTensorToLinalgPass 25 : public ConvertTensorToLinalgBase<ConvertTensorToLinalgPass> { runOnOperation()26 void runOnOperation() override { 27 auto &context = getContext(); 28 ConversionTarget target(context); 29 target.addLegalDialect<mlir::arith::ArithmeticDialect, 30 mlir::linalg::LinalgDialect, 31 mlir::tensor::TensorDialect>(); 32 target.addIllegalOp<mlir::tensor::PadOp>(); 33 34 RewritePatternSet patterns(&context); 35 populateTensorToLinalgPatterns(patterns); 36 37 if (failed(applyPartialConversion(getOperation(), target, 38 std::move(patterns)))) 39 return signalPassFailure(); 40 } 41 }; 42 } // namespace 43 44 std::unique_ptr<OperationPass<ModuleOp>> createConvertTensorToLinalgPass()45mlir::createConvertTensorToLinalgPass() { 46 return std::make_unique<ConvertTensorToLinalgPass>(); 47 } 48