1*cf6a7c19SMahesh Ravishankar //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
2*cf6a7c19SMahesh Ravishankar //
3*cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information.
5*cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*cf6a7c19SMahesh Ravishankar //
7*cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===//
8*cf6a7c19SMahesh Ravishankar 
9*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
10*cf6a7c19SMahesh Ravishankar 
11*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
12*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
14*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/IR/Linalg.h"
15*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Utils/Utils.h"
16*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h"
17*cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h"
18*cf6a7c19SMahesh Ravishankar 
19*cf6a7c19SMahesh Ravishankar using namespace mlir;
20*cf6a7c19SMahesh Ravishankar using namespace mlir::linalg;
21*cf6a7c19SMahesh Ravishankar 
22*cf6a7c19SMahesh Ravishankar namespace {
23*cf6a7c19SMahesh Ravishankar 
24*cf6a7c19SMahesh Ravishankar /// External model implementation of TilingInterface for LinalgOps. An external
25*cf6a7c19SMahesh Ravishankar /// model implementation is used for now till the use of `TilingInterface` is
26*cf6a7c19SMahesh Ravishankar /// on-par with the current Linalg tiling + fusion patterns. Once it is
27*cf6a7c19SMahesh Ravishankar /// maybe possible to move this into the op-definition (though there are
28*cf6a7c19SMahesh Ravishankar /// advantages to leaving it as an external model)
29*cf6a7c19SMahesh Ravishankar template <typename LinalgOpTy>
30*cf6a7c19SMahesh Ravishankar struct LinalgOpTilingInterface
31*cf6a7c19SMahesh Ravishankar     : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
32*cf6a7c19SMahesh Ravishankar                                             LinalgOpTy> {
33*cf6a7c19SMahesh Ravishankar 
34*cf6a7c19SMahesh Ravishankar   /// Return the destination operands.
35*cf6a7c19SMahesh Ravishankar   SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
36*cf6a7c19SMahesh Ravishankar     return llvm::cast<LinalgOp>(op).getOutputOperands();
37*cf6a7c19SMahesh Ravishankar   }
38*cf6a7c19SMahesh Ravishankar 
39*cf6a7c19SMahesh Ravishankar   /// Return the loop iterator type.
40*cf6a7c19SMahesh Ravishankar   SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
41*cf6a7c19SMahesh Ravishankar     LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
42*cf6a7c19SMahesh Ravishankar     return llvm::to_vector(
43*cf6a7c19SMahesh Ravishankar         llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
44*cf6a7c19SMahesh Ravishankar           return strAttr.cast<StringAttr>().getValue();
45*cf6a7c19SMahesh Ravishankar         }));
46*cf6a7c19SMahesh Ravishankar   }
47*cf6a7c19SMahesh Ravishankar 
48*cf6a7c19SMahesh Ravishankar   /// Return the iteration domain range.
49*cf6a7c19SMahesh Ravishankar   SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
50*cf6a7c19SMahesh Ravishankar     Location loc = op->getLoc();
51*cf6a7c19SMahesh Ravishankar     LinalgOp linalgOp = cast<LinalgOp>(op);
52*cf6a7c19SMahesh Ravishankar     auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
53*cf6a7c19SMahesh Ravishankar     AffineMap map = linalgOp.getShapesToLoopsMap();
54*cf6a7c19SMahesh Ravishankar     Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
55*cf6a7c19SMahesh Ravishankar     Value one = b.create<arith::ConstantIndexOp>(loc, 1);
56*cf6a7c19SMahesh Ravishankar     return llvm::to_vector(llvm::map_range(
57*cf6a7c19SMahesh Ravishankar         applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) {
58*cf6a7c19SMahesh Ravishankar           return Range{zero, v, one};
59*cf6a7c19SMahesh Ravishankar         }));
60*cf6a7c19SMahesh Ravishankar   }
61*cf6a7c19SMahesh Ravishankar 
62*cf6a7c19SMahesh Ravishankar   // Instantiate the tiled implementation of the operation.
63*cf6a7c19SMahesh Ravishankar   SmallVector<Operation *>
64*cf6a7c19SMahesh Ravishankar   getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
65*cf6a7c19SMahesh Ravishankar                          ArrayRef<OpFoldResult> offsets,
66*cf6a7c19SMahesh Ravishankar                          ArrayRef<OpFoldResult> sizes,
67*cf6a7c19SMahesh Ravishankar                          bool tileDestOperands) const {
68*cf6a7c19SMahesh Ravishankar     // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
69*cf6a7c19SMahesh Ravishankar     // specified could lead to out of bounds accesses.
70*cf6a7c19SMahesh Ravishankar     Location loc = op->getLoc();
71*cf6a7c19SMahesh Ravishankar     LinalgOp linalgOp = cast<LinalgOp>(op);
72*cf6a7c19SMahesh Ravishankar     SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
73*cf6a7c19SMahesh Ravishankar     SmallVector<Value, 4> tiledOperands = makeTiledShapes(
74*cf6a7c19SMahesh Ravishankar         b, loc, linalgOp, valuesToTile,
75*cf6a7c19SMahesh Ravishankar         getValueOrCreateConstantIndexOp(b, loc, offsets),
76*cf6a7c19SMahesh Ravishankar         getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);
77*cf6a7c19SMahesh Ravishankar 
78*cf6a7c19SMahesh Ravishankar     SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
79*cf6a7c19SMahesh Ravishankar         linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
80*cf6a7c19SMahesh Ravishankar           return tiledOperands[opOperand->getOperandNumber()].getType();
81*cf6a7c19SMahesh Ravishankar         }));
82*cf6a7c19SMahesh Ravishankar 
83*cf6a7c19SMahesh Ravishankar     Operation *tiledOp =
84*cf6a7c19SMahesh Ravishankar         linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
85*cf6a7c19SMahesh Ravishankar 
86*cf6a7c19SMahesh Ravishankar     return {tiledOp};
87*cf6a7c19SMahesh Ravishankar   }
88*cf6a7c19SMahesh Ravishankar 
89*cf6a7c19SMahesh Ravishankar   // Return the details of the output tile generated by the tiled
90*cf6a7c19SMahesh Ravishankar   // implementation.
91*cf6a7c19SMahesh Ravishankar   LogicalResult
92*cf6a7c19SMahesh Ravishankar   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
93*cf6a7c19SMahesh Ravishankar                         ArrayRef<OpFoldResult> offsets,
94*cf6a7c19SMahesh Ravishankar                         ArrayRef<OpFoldResult> sizes,
95*cf6a7c19SMahesh Ravishankar                         SmallVector<OpFoldResult> &resultOffsets,
96*cf6a7c19SMahesh Ravishankar                         SmallVector<OpFoldResult> &resultSizes) const {
97*cf6a7c19SMahesh Ravishankar     Location loc = op->getLoc();
98*cf6a7c19SMahesh Ravishankar     LinalgOp linalgOp = cast<LinalgOp>(op);
99*cf6a7c19SMahesh Ravishankar 
100*cf6a7c19SMahesh Ravishankar     AffineExpr d0;
101*cf6a7c19SMahesh Ravishankar     bindDims(b.getContext(), d0);
102*cf6a7c19SMahesh Ravishankar 
103*cf6a7c19SMahesh Ravishankar     auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc,
104*cf6a7c19SMahesh Ravishankar                                                AffineExpr expr,
105*cf6a7c19SMahesh Ravishankar                                                ValueRange operands) -> Value {
106*cf6a7c19SMahesh Ravishankar       AffineMap map = AffineMap::inferFromExprList({expr}).front();
107*cf6a7c19SMahesh Ravishankar       SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
108*cf6a7c19SMahesh Ravishankar       mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
109*cf6a7c19SMahesh Ravishankar       canonicalizeMapAndOperands(&map, &normalizedOperands);
110*cf6a7c19SMahesh Ravishankar       return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
111*cf6a7c19SMahesh Ravishankar     };
112*cf6a7c19SMahesh Ravishankar 
113*cf6a7c19SMahesh Ravishankar     SmallVector<Value> sizeVals =
114*cf6a7c19SMahesh Ravishankar         getValueOrCreateConstantIndexOp(b, loc, sizes);
115*cf6a7c19SMahesh Ravishankar     SmallVector<Value> subShapeSizes =
116*cf6a7c19SMahesh Ravishankar         llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) {
117*cf6a7c19SMahesh Ravishankar           return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v);
118*cf6a7c19SMahesh Ravishankar         }));
119*cf6a7c19SMahesh Ravishankar     OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
120*cf6a7c19SMahesh Ravishankar     Value sliceOpResult =
121*cf6a7c19SMahesh Ravishankar         makeTiledShape(b, loc, outOperand->get(), sizeVals,
122*cf6a7c19SMahesh Ravishankar                        linalgOp.getTiedIndexingMap(outOperand),
123*cf6a7c19SMahesh Ravishankar                        getValueOrCreateConstantIndexOp(b, loc, offsets),
124*cf6a7c19SMahesh Ravishankar                        /*ubs*/ {}, subShapeSizes, true);
125*cf6a7c19SMahesh Ravishankar     auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
126*cf6a7c19SMahesh Ravishankar     if (!sliceOp)
127*cf6a7c19SMahesh Ravishankar       return failure();
128*cf6a7c19SMahesh Ravishankar     resultOffsets = sliceOp.getMixedOffsets();
129*cf6a7c19SMahesh Ravishankar     resultSizes = sliceOp.getMixedSizes();
130*cf6a7c19SMahesh Ravishankar     return success();
131*cf6a7c19SMahesh Ravishankar   }
132*cf6a7c19SMahesh Ravishankar };
133*cf6a7c19SMahesh Ravishankar 
134*cf6a7c19SMahesh Ravishankar } // namespace
135*cf6a7c19SMahesh Ravishankar 
136*cf6a7c19SMahesh Ravishankar template <typename OpType> static void registerOne(MLIRContext *ctx) {
137*cf6a7c19SMahesh Ravishankar   OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
138*cf6a7c19SMahesh Ravishankar }
139*cf6a7c19SMahesh Ravishankar 
140*cf6a7c19SMahesh Ravishankar /// Variadic helper function.
141*cf6a7c19SMahesh Ravishankar template <typename... OpTypes> static void registerAll(MLIRContext *ctx) {
142*cf6a7c19SMahesh Ravishankar   // FIXME: In c++17 this can be simplified by using 'fold expressions'.
143*cf6a7c19SMahesh Ravishankar   (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
144*cf6a7c19SMahesh Ravishankar }
145*cf6a7c19SMahesh Ravishankar 
146*cf6a7c19SMahesh Ravishankar #define GET_OP_LIST
147*cf6a7c19SMahesh Ravishankar 
148*cf6a7c19SMahesh Ravishankar void mlir::linalg::registerTilingInterfaceExternalModels(
149*cf6a7c19SMahesh Ravishankar     DialectRegistry &registry) {
150*cf6a7c19SMahesh Ravishankar   registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
151*cf6a7c19SMahesh Ravishankar     registerOne<linalg::GenericOp>(ctx);
152*cf6a7c19SMahesh Ravishankar     registerAll<
153*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
154*cf6a7c19SMahesh Ravishankar         >(ctx);
155*cf6a7c19SMahesh Ravishankar   });
156*cf6a7c19SMahesh Ravishankar }
157