1 //===- TosaToTensorPass.cpp - Lowering Tosa to Tensor Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass legalizes Tosa operations to the Tensor dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../PassDetail.h"
14 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
18 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
19 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Pass/PassManager.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 using namespace mlir;
26 using namespace tosa;
27 
28 namespace {
29 struct TosaToTensor : public TosaToTensorBase<TosaToTensor> {
30 public:
runOnOperation__anon60a245de0111::TosaToTensor31   void runOnOperation() override {
32     RewritePatternSet patterns(&getContext());
33     ConversionTarget target(getContext());
34     target.addIllegalOp<tosa::SliceOp>();
35     target.addLegalDialect<arith::ArithmeticDialect>();
36     target.addLegalDialect<tensor::TensorDialect>();
37 
38     mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
39 
40     if (failed(applyPartialConversion(getOperation(), target,
41                                       std::move(patterns))))
42       signalPassFailure();
43   }
44 };
45 } // namespace
46 
createTosaToTensor()47 std::unique_ptr<Pass> mlir::tosa::createTosaToTensor() {
48   return std::make_unique<TosaToTensor>();
49 }
50