1 //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Utils/Utils.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Interfaces/TilingInterface.h"
18 
19 using namespace mlir;
20 using namespace mlir::linalg;
21 
22 namespace {
23 
24 /// External model implementation of TilingInterface for LinalgOps. An external
25 /// model implementation is used for now till the use of `TilingInterface` is
26 /// on-par with the current Linalg tiling + fusion patterns. Once it is
27 /// maybe possible to move this into the op-definition (though there are
28 /// advantages to leaving it as an external model)
29 template <typename LinalgOpTy>
30 struct LinalgOpTilingInterface
31     : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
32                                             LinalgOpTy> {
33   /// Return the destination operands.
34   SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
35     return llvm::cast<LinalgOp>(op).getOutputOperands();
36   }
37 
38   /// Return the loop iterator type.
39   SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
40     LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
41     return llvm::to_vector(
42         llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
43           return strAttr.cast<StringAttr>().getValue();
44         }));
45   }
46 
47   /// Return the iteration domain range.
48   SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
49     OpBuilder::InsertionGuard g(b);
50     b.setInsertionPoint(op);
51     Location loc = op->getLoc();
52     LinalgOp linalgOp = cast<LinalgOp>(op);
53     auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
54     AffineMap map = linalgOp.getShapesToLoopsMap();
55     Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
56     Value one = b.create<arith::ConstantIndexOp>(loc, 1);
57     return llvm::to_vector(llvm::map_range(
58         applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) {
59           return Range{zero, v, one};
60         }));
61   }
62 
63   // Instantiate the tiled implementation of the operation.
64   SmallVector<Operation *>
65   getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
66                          ArrayRef<OpFoldResult> offsets,
67                          ArrayRef<OpFoldResult> sizes,
68                          bool tileDestOperands) const {
69     // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
70     // specified could lead to out of bounds accesses.
71     Location loc = op->getLoc();
72     LinalgOp linalgOp = cast<LinalgOp>(op);
73     SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
74     SmallVector<Value, 4> tiledOperands = makeTiledShapes(
75         b, loc, linalgOp, valuesToTile,
76         getValueOrCreateConstantIndexOp(b, loc, offsets),
77         getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);
78 
79     SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
80         linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
81           return tiledOperands[opOperand->getOperandNumber()].getType();
82         }));
83 
84     Operation *tiledOp =
85         linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
86 
87     return {tiledOp};
88   }
89 
90   // Return the details of the output tile generated by the tiled
91   // implementation.
92   LogicalResult
93   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
94                         ArrayRef<OpFoldResult> offsets,
95                         ArrayRef<OpFoldResult> sizes,
96                         SmallVector<OpFoldResult> &resultOffsets,
97                         SmallVector<OpFoldResult> &resultSizes) const {
98     Location loc = op->getLoc();
99     LinalgOp linalgOp = cast<LinalgOp>(op);
100 
101     AffineExpr d0;
102     bindDims(b.getContext(), d0);
103 
104     auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc,
105                                                AffineExpr expr,
106                                                ValueRange operands) -> Value {
107       AffineMap map = AffineMap::inferFromExprList({expr}).front();
108       SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
109       mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
110       canonicalizeMapAndOperands(&map, &normalizedOperands);
111       return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
112     };
113 
114     SmallVector<Value> sizeVals =
115         getValueOrCreateConstantIndexOp(b, loc, sizes);
116     SmallVector<Value> subShapeSizes =
117         llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) {
118           return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v);
119         }));
120     OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
121     Value sliceOpResult =
122         makeTiledShape(b, loc, outOperand->get(), sizeVals,
123                        linalgOp.getTiedIndexingMap(outOperand),
124                        getValueOrCreateConstantIndexOp(b, loc, offsets),
125                        /*ubs*/ {}, subShapeSizes, true);
126     auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
127     if (!sliceOp)
128       return failure();
129     resultOffsets = sliceOp.getMixedOffsets();
130     resultSizes = sliceOp.getMixedSizes();
131     return success();
132   }
133 
134   FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
135                                            unsigned resultNumber,
136                                            ValueRange dest,
137                                            ArrayRef<OpFoldResult> offsets,
138                                            ArrayRef<OpFoldResult> sizes,
139                                            bool tileDestOperands) const {
140     auto linalgOp = cast<LinalgOp>(op);
141 
142     // Check that the indexing map used for the output is a projected
143     // permutation. This could be relaxed with a more general approach that can
144     // map the offsets and sizes from the result to iteration space tiles
145     // (filling in full extent for dimensions not used to access the result).
146     AffineMap indexingMap =
147         linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber));
148     if (!indexingMap.isProjectedPermutation()) {
149       return op->emitOpError(
150           "unhandled tiled implementation generation when result is not "
151           "accessed using a permuted projection");
152     }
153 
154     auto numLoops = linalgOp.getNumLoops();
155     auto tilingInterfaceOp = cast<TilingInterface>(op);
156     SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
157         iterationTileSizes(numLoops);
158     if (!indexingMap.isPermutation()) {
159       SmallVector<Range> iterationDomain =
160           tilingInterfaceOp.getIterationDomain(b);
161       for (auto range : llvm::enumerate(iterationDomain)) {
162         iterationTileOffsets[range.index()] = range.value().offset;
163         iterationTileSizes[range.index()] = range.value().size;
164       }
165     }
166     for (auto resultExpr : llvm::enumerate(indexingMap.getResults())) {
167       unsigned dimPosition =
168           resultExpr.value().cast<AffineDimExpr>().getPosition();
169       iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
170       iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
171     }
172 
173     SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation(
174         b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands);
175     if (tiledOp.size() != 1)
176       return op->emitOpError("failed to generate tiled implementation");
177 
178     return tiledOp[0]->getResult(resultNumber);
179   }
180 };
181 
182 } // namespace
183 
184 template <typename OpType>
185 static void registerOne(MLIRContext *ctx) {
186   OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
187 }
188 
189 /// Variadic helper function.
190 template <typename... OpTypes>
191 static void registerAll(MLIRContext *ctx) {
192   // FIXME: In c++17 this can be simplified by using 'fold expressions'.
193   (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
194 }
195 
196 #define GET_OP_LIST
197 
198 void mlir::linalg::registerTilingInterfaceExternalModels(
199     DialectRegistry &registry) {
200   registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
201     registerOne<linalg::GenericOp>(ctx);
202     registerAll<
203 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
204         >(ctx);
205   });
206 }
207