//===- SwapExtractSliceWithProducer.cpp - Swapping `tensor.extract_slice` ---=//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Swap a `tensor.extract_slice` with the producer of the source if the producer
// implements the `TilingInterface`. When used in conjunction with tiling this
// effectively tiles + fuses the producer with its consumer.
//
//===----------------------------------------------------------------------===//
142f637fe7SMahesh Ravishankar
152f637fe7SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
162f637fe7SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h"
172f637fe7SMahesh Ravishankar #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
182f637fe7SMahesh Ravishankar #include "mlir/Dialect/Utils/StaticValueUtils.h"
192f637fe7SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h"
202f637fe7SMahesh Ravishankar
212f637fe7SMahesh Ravishankar using namespace mlir;
222f637fe7SMahesh Ravishankar
replaceExtractSliceWithTiledProducer(OpBuilder & builder,tensor::ExtractSliceOp sliceOp,OpResult producer)232f637fe7SMahesh Ravishankar FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer(
242f637fe7SMahesh Ravishankar OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
252f637fe7SMahesh Ravishankar auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
262f637fe7SMahesh Ravishankar if (!producerOp)
272f637fe7SMahesh Ravishankar return failure();
282f637fe7SMahesh Ravishankar
292f637fe7SMahesh Ravishankar // `TilingInterface` currently only supports strides being 1.
302f637fe7SMahesh Ravishankar if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
312f637fe7SMahesh Ravishankar return !isConstantIntValue(ofr, 1);
322f637fe7SMahesh Ravishankar }))
332f637fe7SMahesh Ravishankar return failure();
342f637fe7SMahesh Ravishankar
352f637fe7SMahesh Ravishankar FailureOr<Value> tiledResult = producerOp.generateResultTileValue(
362f637fe7SMahesh Ravishankar builder, producer.getResultNumber(),
372f637fe7SMahesh Ravishankar producerOp.getDestinationOperands(builder), sliceOp.getMixedOffsets(),
382f637fe7SMahesh Ravishankar sliceOp.getMixedSizes(), true);
392f637fe7SMahesh Ravishankar if (failed(tiledResult))
402f637fe7SMahesh Ravishankar return failure();
412f637fe7SMahesh Ravishankar
42*c27d8152SKazu Hirata return tiledResult.value();
432f637fe7SMahesh Ravishankar }
44