1 //===- SwapExtractSliceWithProducer.cpp - Swapping `tensor.extract_slice` ---=//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Swap a `tensor.extract_slice` with the producer of the source if the producer
10 // implements the `TilingInterface`. When used in conjunction with tiling this
11 // effectively tiles + fuses the producer with its consumer.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
18 #include "mlir/Dialect/Utils/StaticValueUtils.h"
19 #include "mlir/Interfaces/TilingInterface.h"
20
21 using namespace mlir;
22
replaceExtractSliceWithTiledProducer(OpBuilder & builder,tensor::ExtractSliceOp sliceOp,OpResult producer)23 FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer(
24 OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
25 auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
26 if (!producerOp)
27 return failure();
28
29 // `TilingInterface` currently only supports strides being 1.
30 if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
31 return !isConstantIntValue(ofr, 1);
32 }))
33 return failure();
34
35 FailureOr<Value> tiledResult = producerOp.generateResultTileValue(
36 builder, producer.getResultNumber(),
37 producerOp.getDestinationOperands(builder), sliceOp.getMixedOffsets(),
38 sliceOp.getMixedSizes(), true);
39 if (failed(tiledResult))
40 return failure();
41
42 return tiledResult.value();
43 }
44