1 //===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Tensor dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 using namespace mlir;
21 using namespace tosa;
22 
23 namespace {
24 
25 class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
26 public:
27   using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
28 
matchAndRewrite(tosa::SliceOp sliceOp,PatternRewriter & rewriter) const29   LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
30                                 PatternRewriter &rewriter) const final {
31     Location loc = sliceOp.getLoc();
32     Value input = sliceOp.getInput();
33     SmallVector<int64_t> strides;
34     auto starts = sliceOp.getStart();
35     auto sizes = sliceOp.getSize();
36     strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
37 
38     SmallVector<Value> dynSizes;
39     for (const auto &i : llvm::enumerate(sizes)) {
40       int64_t size = i.value().cast<IntegerAttr>().getInt();
41       size_t index = i.index();
42       if (size != ShapedType::kDynamicSize)
43         continue;
44 
45       auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
46       auto offset = rewriter.create<arith::ConstantOp>(
47           loc,
48           rewriter.getIndexAttr(starts[index].cast<IntegerAttr>().getInt()));
49       dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
50     }
51 
52     auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
53         sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
54         ValueRange({}), starts, sizes, rewriter.getI64ArrayAttr(strides));
55 
56     rewriter.replaceOp(sliceOp, newSliceOp.getResult());
57     return success();
58   }
59 };
60 
61 } // namespace
62 
populateTosaToTensorConversionPatterns(RewritePatternSet * patterns)63 void mlir::tosa::populateTosaToTensorConversionPatterns(
64     RewritePatternSet *patterns) {
65   patterns->add<SliceOpConverter>(patterns->getContext());
66 }
67