1cf6a7c19SMahesh Ravishankar //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
2cf6a7c19SMahesh Ravishankar //
3cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information.
5cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6cf6a7c19SMahesh Ravishankar //
7cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===//
8cf6a7c19SMahesh Ravishankar 
9cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
10cf6a7c19SMahesh Ravishankar 
11cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
12cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
14cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/IR/Linalg.h"
15cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Utils/Utils.h"
16cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h"
17cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h"
18cf6a7c19SMahesh Ravishankar 
19cf6a7c19SMahesh Ravishankar using namespace mlir;
20cf6a7c19SMahesh Ravishankar using namespace mlir::linalg;
21cf6a7c19SMahesh Ravishankar 
22cf6a7c19SMahesh Ravishankar namespace {
23cf6a7c19SMahesh Ravishankar 
24cf6a7c19SMahesh Ravishankar /// External model implementation of TilingInterface for LinalgOps. An external
25cf6a7c19SMahesh Ravishankar /// model implementation is used for now till the use of `TilingInterface` is
26cf6a7c19SMahesh Ravishankar /// on-par with the current Linalg tiling + fusion patterns. Once it is
27cf6a7c19SMahesh Ravishankar /// maybe possible to move this into the op-definition (though there are
28cf6a7c19SMahesh Ravishankar /// advantages to leaving it as an external model)
29cf6a7c19SMahesh Ravishankar template <typename LinalgOpTy>
30cf6a7c19SMahesh Ravishankar struct LinalgOpTilingInterface
31cf6a7c19SMahesh Ravishankar     : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
32cf6a7c19SMahesh Ravishankar                                             LinalgOpTy> {
33cf6a7c19SMahesh Ravishankar   /// Return the destination operands.
getDestinationOperands__anon8010582c0111::LinalgOpTilingInterface34cf6a7c19SMahesh Ravishankar   SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
35cf6a7c19SMahesh Ravishankar     return llvm::cast<LinalgOp>(op).getOutputOperands();
36cf6a7c19SMahesh Ravishankar   }
37cf6a7c19SMahesh Ravishankar 
38cf6a7c19SMahesh Ravishankar   /// Return the loop iterator type.
getLoopIteratorTypes__anon8010582c0111::LinalgOpTilingInterface39cf6a7c19SMahesh Ravishankar   SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
40cf6a7c19SMahesh Ravishankar     LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
41cf6a7c19SMahesh Ravishankar     return llvm::to_vector(
42cf6a7c19SMahesh Ravishankar         llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
43cf6a7c19SMahesh Ravishankar           return strAttr.cast<StringAttr>().getValue();
44cf6a7c19SMahesh Ravishankar         }));
45cf6a7c19SMahesh Ravishankar   }
46cf6a7c19SMahesh Ravishankar 
47cf6a7c19SMahesh Ravishankar   /// Return the iteration domain range.
getIterationDomain__anon8010582c0111::LinalgOpTilingInterface48cf6a7c19SMahesh Ravishankar   SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
492f637fe7SMahesh Ravishankar     OpBuilder::InsertionGuard g(b);
502f637fe7SMahesh Ravishankar     b.setInsertionPoint(op);
51cf6a7c19SMahesh Ravishankar     Location loc = op->getLoc();
52cf6a7c19SMahesh Ravishankar     LinalgOp linalgOp = cast<LinalgOp>(op);
53cf6a7c19SMahesh Ravishankar     auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
54cf6a7c19SMahesh Ravishankar     AffineMap map = linalgOp.getShapesToLoopsMap();
55cf6a7c19SMahesh Ravishankar     Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
56cf6a7c19SMahesh Ravishankar     Value one = b.create<arith::ConstantIndexOp>(loc, 1);
57cf6a7c19SMahesh Ravishankar     return llvm::to_vector(llvm::map_range(
58cf6a7c19SMahesh Ravishankar         applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) {
59cf6a7c19SMahesh Ravishankar           return Range{zero, v, one};
60cf6a7c19SMahesh Ravishankar         }));
61cf6a7c19SMahesh Ravishankar   }
62cf6a7c19SMahesh Ravishankar 
63cf6a7c19SMahesh Ravishankar   // Instantiate the tiled implementation of the operation.
64cf6a7c19SMahesh Ravishankar   SmallVector<Operation *>
getTiledImplementation__anon8010582c0111::LinalgOpTilingInterface65cf6a7c19SMahesh Ravishankar   getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
66cf6a7c19SMahesh Ravishankar                          ArrayRef<OpFoldResult> offsets,
67cf6a7c19SMahesh Ravishankar                          ArrayRef<OpFoldResult> sizes,
68cf6a7c19SMahesh Ravishankar                          bool tileDestOperands) const {
69cf6a7c19SMahesh Ravishankar     // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
70cf6a7c19SMahesh Ravishankar     // specified could lead to out of bounds accesses.
71cf6a7c19SMahesh Ravishankar     Location loc = op->getLoc();
72cf6a7c19SMahesh Ravishankar     LinalgOp linalgOp = cast<LinalgOp>(op);
73cf6a7c19SMahesh Ravishankar     SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
7481b62f7fSAlex Zinenko     SmallVector<Value> offsetValues =
7581b62f7fSAlex Zinenko         getValueOrCreateConstantIndexOp(b, loc, offsets);
76cf6a7c19SMahesh Ravishankar     SmallVector<Value, 4> tiledOperands = makeTiledShapes(
7781b62f7fSAlex Zinenko         b, loc, linalgOp, valuesToTile, offsetValues,
78cf6a7c19SMahesh Ravishankar         getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);
79cf6a7c19SMahesh Ravishankar 
80cf6a7c19SMahesh Ravishankar     SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
81cf6a7c19SMahesh Ravishankar         linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
82cf6a7c19SMahesh Ravishankar           return tiledOperands[opOperand->getOperandNumber()].getType();
83cf6a7c19SMahesh Ravishankar         }));
84cf6a7c19SMahesh Ravishankar 
85cf6a7c19SMahesh Ravishankar     Operation *tiledOp =
86cf6a7c19SMahesh Ravishankar         linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
8781b62f7fSAlex Zinenko     offsetIndices(b, cast<LinalgOp>(tiledOp), offsetValues);
88cf6a7c19SMahesh Ravishankar 
89cf6a7c19SMahesh Ravishankar     return {tiledOp};
90cf6a7c19SMahesh Ravishankar   }
91cf6a7c19SMahesh Ravishankar 
92cf6a7c19SMahesh Ravishankar   // Return the details of the output tile generated by the tiled
93cf6a7c19SMahesh Ravishankar   // implementation.
94cf6a7c19SMahesh Ravishankar   LogicalResult
getResultTilePosition__anon8010582c0111::LinalgOpTilingInterface95cf6a7c19SMahesh Ravishankar   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
96cf6a7c19SMahesh Ravishankar                         ArrayRef<OpFoldResult> offsets,
97cf6a7c19SMahesh Ravishankar                         ArrayRef<OpFoldResult> sizes,
98cf6a7c19SMahesh Ravishankar                         SmallVector<OpFoldResult> &resultOffsets,
99cf6a7c19SMahesh Ravishankar                         SmallVector<OpFoldResult> &resultSizes) const {
100cf6a7c19SMahesh Ravishankar     Location loc = op->getLoc();
101cf6a7c19SMahesh Ravishankar     LinalgOp linalgOp = cast<LinalgOp>(op);
102cf6a7c19SMahesh Ravishankar 
103cf6a7c19SMahesh Ravishankar     AffineExpr d0;
104cf6a7c19SMahesh Ravishankar     bindDims(b.getContext(), d0);
105cf6a7c19SMahesh Ravishankar 
106cf6a7c19SMahesh Ravishankar     auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc,
107cf6a7c19SMahesh Ravishankar                                                AffineExpr expr,
108cf6a7c19SMahesh Ravishankar                                                ValueRange operands) -> Value {
109cf6a7c19SMahesh Ravishankar       AffineMap map = AffineMap::inferFromExprList({expr}).front();
110cf6a7c19SMahesh Ravishankar       SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
111cf6a7c19SMahesh Ravishankar       mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
112cf6a7c19SMahesh Ravishankar       canonicalizeMapAndOperands(&map, &normalizedOperands);
113cf6a7c19SMahesh Ravishankar       return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
114cf6a7c19SMahesh Ravishankar     };
115cf6a7c19SMahesh Ravishankar 
116cf6a7c19SMahesh Ravishankar     SmallVector<Value> sizeVals =
117cf6a7c19SMahesh Ravishankar         getValueOrCreateConstantIndexOp(b, loc, sizes);
118cf6a7c19SMahesh Ravishankar     SmallVector<Value> subShapeSizes =
119cf6a7c19SMahesh Ravishankar         llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) {
120cf6a7c19SMahesh Ravishankar           return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v);
121cf6a7c19SMahesh Ravishankar         }));
122cf6a7c19SMahesh Ravishankar     OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
123cf6a7c19SMahesh Ravishankar     Value sliceOpResult =
124cf6a7c19SMahesh Ravishankar         makeTiledShape(b, loc, outOperand->get(), sizeVals,
125cf6a7c19SMahesh Ravishankar                        linalgOp.getTiedIndexingMap(outOperand),
126cf6a7c19SMahesh Ravishankar                        getValueOrCreateConstantIndexOp(b, loc, offsets),
127cf6a7c19SMahesh Ravishankar                        /*ubs*/ {}, subShapeSizes, true);
128cf6a7c19SMahesh Ravishankar     auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
129cf6a7c19SMahesh Ravishankar     if (!sliceOp)
130cf6a7c19SMahesh Ravishankar       return failure();
131cf6a7c19SMahesh Ravishankar     resultOffsets = sliceOp.getMixedOffsets();
132cf6a7c19SMahesh Ravishankar     resultSizes = sliceOp.getMixedSizes();
133cf6a7c19SMahesh Ravishankar     return success();
134cf6a7c19SMahesh Ravishankar   }
1352f637fe7SMahesh Ravishankar 
generateResultTileValue__anon8010582c0111::LinalgOpTilingInterface1362f637fe7SMahesh Ravishankar   FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
1372f637fe7SMahesh Ravishankar                                            unsigned resultNumber,
1382f637fe7SMahesh Ravishankar                                            ValueRange dest,
1392f637fe7SMahesh Ravishankar                                            ArrayRef<OpFoldResult> offsets,
1402f637fe7SMahesh Ravishankar                                            ArrayRef<OpFoldResult> sizes,
1412f637fe7SMahesh Ravishankar                                            bool tileDestOperands) const {
1422f637fe7SMahesh Ravishankar     auto linalgOp = cast<LinalgOp>(op);
1432f637fe7SMahesh Ravishankar 
1442f637fe7SMahesh Ravishankar     // Check that the indexing map used for the output is a projected
1452f637fe7SMahesh Ravishankar     // permutation. This could be relaxed with a more general approach that can
1462f637fe7SMahesh Ravishankar     // map the offsets and sizes from the result to iteration space tiles
1472f637fe7SMahesh Ravishankar     // (filling in full extent for dimensions not used to access the result).
1482f637fe7SMahesh Ravishankar     AffineMap indexingMap =
1492f637fe7SMahesh Ravishankar         linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber));
1502f637fe7SMahesh Ravishankar     if (!indexingMap.isProjectedPermutation()) {
1512f637fe7SMahesh Ravishankar       return op->emitOpError(
1522f637fe7SMahesh Ravishankar           "unhandled tiled implementation generation when result is not "
1532f637fe7SMahesh Ravishankar           "accessed using a permuted projection");
1542f637fe7SMahesh Ravishankar     }
1552f637fe7SMahesh Ravishankar 
1562f637fe7SMahesh Ravishankar     auto numLoops = linalgOp.getNumLoops();
1572f637fe7SMahesh Ravishankar     auto tilingInterfaceOp = cast<TilingInterface>(op);
1582f637fe7SMahesh Ravishankar     SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
1592f637fe7SMahesh Ravishankar         iterationTileSizes(numLoops);
1602f637fe7SMahesh Ravishankar     if (!indexingMap.isPermutation()) {
1612f637fe7SMahesh Ravishankar       SmallVector<Range> iterationDomain =
1622f637fe7SMahesh Ravishankar           tilingInterfaceOp.getIterationDomain(b);
163*07628a94SAdrian Kuegel       for (const auto &range : llvm::enumerate(iterationDomain)) {
1642f637fe7SMahesh Ravishankar         iterationTileOffsets[range.index()] = range.value().offset;
1652f637fe7SMahesh Ravishankar         iterationTileSizes[range.index()] = range.value().size;
1662f637fe7SMahesh Ravishankar       }
1672f637fe7SMahesh Ravishankar     }
168*07628a94SAdrian Kuegel     for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
1692f637fe7SMahesh Ravishankar       unsigned dimPosition =
1702f637fe7SMahesh Ravishankar           resultExpr.value().cast<AffineDimExpr>().getPosition();
1712f637fe7SMahesh Ravishankar       iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
1722f637fe7SMahesh Ravishankar       iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
1732f637fe7SMahesh Ravishankar     }
1742f637fe7SMahesh Ravishankar 
1752f637fe7SMahesh Ravishankar     SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation(
1762f637fe7SMahesh Ravishankar         b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands);
1772f637fe7SMahesh Ravishankar     if (tiledOp.size() != 1)
1782f637fe7SMahesh Ravishankar       return op->emitOpError("failed to generate tiled implementation");
1792f637fe7SMahesh Ravishankar 
1802f637fe7SMahesh Ravishankar     return tiledOp[0]->getResult(resultNumber);
1812f637fe7SMahesh Ravishankar   }
182cf6a7c19SMahesh Ravishankar };
183cf6a7c19SMahesh Ravishankar 
184cf6a7c19SMahesh Ravishankar } // namespace
185cf6a7c19SMahesh Ravishankar 
1862f637fe7SMahesh Ravishankar template <typename OpType>
registerOne(MLIRContext * ctx)1872f637fe7SMahesh Ravishankar static void registerOne(MLIRContext *ctx) {
188cf6a7c19SMahesh Ravishankar   OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
189cf6a7c19SMahesh Ravishankar }
190cf6a7c19SMahesh Ravishankar 
191cf6a7c19SMahesh Ravishankar /// Variadic helper function.
1922f637fe7SMahesh Ravishankar template <typename... OpTypes>
registerAll(MLIRContext * ctx)1932f637fe7SMahesh Ravishankar static void registerAll(MLIRContext *ctx) {
194cf6a7c19SMahesh Ravishankar   // FIXME: In c++17 this can be simplified by using 'fold expressions'.
195cf6a7c19SMahesh Ravishankar   (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
196cf6a7c19SMahesh Ravishankar }
197cf6a7c19SMahesh Ravishankar 
198cf6a7c19SMahesh Ravishankar #define GET_OP_LIST
199cf6a7c19SMahesh Ravishankar 
registerTilingInterfaceExternalModels(DialectRegistry & registry)200cf6a7c19SMahesh Ravishankar void mlir::linalg::registerTilingInterfaceExternalModels(
201cf6a7c19SMahesh Ravishankar     DialectRegistry &registry) {
202cf6a7c19SMahesh Ravishankar   registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
203cf6a7c19SMahesh Ravishankar     registerOne<linalg::GenericOp>(ctx);
204cf6a7c19SMahesh Ravishankar     registerAll<
205cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
206cf6a7c19SMahesh Ravishankar         >(ctx);
207cf6a7c19SMahesh Ravishankar   });
208cf6a7c19SMahesh Ravishankar }
209