1 //===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // These rewriters lower from the Tosa to the Tensor dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 20 using namespace mlir; 21 using namespace tosa; 22 23 namespace { 24 25 class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> { 26 public: 27 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern; 28 29 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, 30 PatternRewriter &rewriter) const final { 31 Location loc = sliceOp.getLoc(); 32 Value input = sliceOp.getInput(); 33 SmallVector<int64_t> strides; 34 auto starts = sliceOp.getStart(); 35 auto sizes = sliceOp.getSize(); 36 strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1); 37 38 SmallVector<Value> dynSizes; 39 for (const auto &i : llvm::enumerate(sizes)) { 40 int64_t size = i.value().cast<IntegerAttr>().getInt(); 41 size_t index = i.index(); 42 if (size != ShapedType::kDynamicSize) 43 continue; 44 45 auto dim = rewriter.create<tensor::DimOp>(loc, input, index); 46 auto offset = rewriter.create<arith::ConstantOp>( 47 loc, 48 rewriter.getIndexAttr(starts[index].cast<IntegerAttr>().getInt())); 49 dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset)); 50 } 51 52 auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>( 53 sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes, 54 ValueRange({}), starts, sizes, rewriter.getI64ArrayAttr(strides)); 55 56 rewriter.replaceOp(sliceOp, newSliceOp.getResult()); 57 return success(); 58 } 59 }; 60 61 } // namespace 62 63 void mlir::tosa::populateTosaToTensorConversionPatterns( 64 RewritePatternSet *patterns) { 65 patterns->add<SliceOpConverter>(patterns->getContext()); 66 } 67