1 //===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Tensor dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20 using namespace mlir;
21 using namespace tosa;
22
23 namespace {
24
25 class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
26 public:
27 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
28
matchAndRewrite(tosa::SliceOp sliceOp,PatternRewriter & rewriter) const29 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
30 PatternRewriter &rewriter) const final {
31 Location loc = sliceOp.getLoc();
32 Value input = sliceOp.getInput();
33 SmallVector<int64_t> strides;
34 auto starts = sliceOp.getStart();
35 auto sizes = sliceOp.getSize();
36 strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
37
38 SmallVector<Value> dynSizes;
39 for (const auto &i : llvm::enumerate(sizes)) {
40 int64_t size = i.value().cast<IntegerAttr>().getInt();
41 size_t index = i.index();
42 if (size != ShapedType::kDynamicSize)
43 continue;
44
45 auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
46 auto offset = rewriter.create<arith::ConstantOp>(
47 loc,
48 rewriter.getIndexAttr(starts[index].cast<IntegerAttr>().getInt()));
49 dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
50 }
51
52 auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
53 sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
54 ValueRange({}), starts, sizes, rewriter.getI64ArrayAttr(strides));
55
56 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
57 return success();
58 }
59 };
60
61 } // namespace
62
populateTosaToTensorConversionPatterns(RewritePatternSet * patterns)63 void mlir::tosa::populateTosaToTensorConversionPatterns(
64 RewritePatternSet *patterns) {
65 patterns->add<SliceOpConverter>(patterns->getContext());
66 }
67