1 //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
10
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Utils/Utils.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Interfaces/TilingInterface.h"
18
19 using namespace mlir;
20 using namespace mlir::linalg;
21
22 namespace {
23
24 /// External model implementation of TilingInterface for LinalgOps. An external
25 /// model implementation is used for now till the use of `TilingInterface` is
26 /// on-par with the current Linalg tiling + fusion patterns. Once it is
27 /// maybe possible to move this into the op-definition (though there are
28 /// advantages to leaving it as an external model)
29 template <typename LinalgOpTy>
30 struct LinalgOpTilingInterface
31 : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
32 LinalgOpTy> {
33 /// Return the destination operands.
getDestinationOperands__anon8010582c0111::LinalgOpTilingInterface34 SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
35 return llvm::cast<LinalgOp>(op).getOutputOperands();
36 }
37
38 /// Return the loop iterator type.
getLoopIteratorTypes__anon8010582c0111::LinalgOpTilingInterface39 SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
40 LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
41 return llvm::to_vector(
42 llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
43 return strAttr.cast<StringAttr>().getValue();
44 }));
45 }
46
47 /// Return the iteration domain range.
getIterationDomain__anon8010582c0111::LinalgOpTilingInterface48 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
49 OpBuilder::InsertionGuard g(b);
50 b.setInsertionPoint(op);
51 Location loc = op->getLoc();
52 LinalgOp linalgOp = cast<LinalgOp>(op);
53 auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
54 AffineMap map = linalgOp.getShapesToLoopsMap();
55 Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
56 Value one = b.create<arith::ConstantIndexOp>(loc, 1);
57 return llvm::to_vector(llvm::map_range(
58 applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) {
59 return Range{zero, v, one};
60 }));
61 }
62
63 // Instantiate the tiled implementation of the operation.
64 SmallVector<Operation *>
getTiledImplementation__anon8010582c0111::LinalgOpTilingInterface65 getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
66 ArrayRef<OpFoldResult> offsets,
67 ArrayRef<OpFoldResult> sizes,
68 bool tileDestOperands) const {
69 // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
70 // specified could lead to out of bounds accesses.
71 Location loc = op->getLoc();
72 LinalgOp linalgOp = cast<LinalgOp>(op);
73 SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
74 SmallVector<Value> offsetValues =
75 getValueOrCreateConstantIndexOp(b, loc, offsets);
76 SmallVector<Value, 4> tiledOperands = makeTiledShapes(
77 b, loc, linalgOp, valuesToTile, offsetValues,
78 getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);
79
80 SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
81 linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
82 return tiledOperands[opOperand->getOperandNumber()].getType();
83 }));
84
85 Operation *tiledOp =
86 linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
87 offsetIndices(b, cast<LinalgOp>(tiledOp), offsetValues);
88
89 return {tiledOp};
90 }
91
92 // Return the details of the output tile generated by the tiled
93 // implementation.
94 LogicalResult
getResultTilePosition__anon8010582c0111::LinalgOpTilingInterface95 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
96 ArrayRef<OpFoldResult> offsets,
97 ArrayRef<OpFoldResult> sizes,
98 SmallVector<OpFoldResult> &resultOffsets,
99 SmallVector<OpFoldResult> &resultSizes) const {
100 Location loc = op->getLoc();
101 LinalgOp linalgOp = cast<LinalgOp>(op);
102
103 AffineExpr d0;
104 bindDims(b.getContext(), d0);
105
106 auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc,
107 AffineExpr expr,
108 ValueRange operands) -> Value {
109 AffineMap map = AffineMap::inferFromExprList({expr}).front();
110 SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
111 mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
112 canonicalizeMapAndOperands(&map, &normalizedOperands);
113 return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
114 };
115
116 SmallVector<Value> sizeVals =
117 getValueOrCreateConstantIndexOp(b, loc, sizes);
118 SmallVector<Value> subShapeSizes =
119 llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) {
120 return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v);
121 }));
122 OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
123 Value sliceOpResult =
124 makeTiledShape(b, loc, outOperand->get(), sizeVals,
125 linalgOp.getTiedIndexingMap(outOperand),
126 getValueOrCreateConstantIndexOp(b, loc, offsets),
127 /*ubs*/ {}, subShapeSizes, true);
128 auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
129 if (!sliceOp)
130 return failure();
131 resultOffsets = sliceOp.getMixedOffsets();
132 resultSizes = sliceOp.getMixedSizes();
133 return success();
134 }
135
generateResultTileValue__anon8010582c0111::LinalgOpTilingInterface136 FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
137 unsigned resultNumber,
138 ValueRange dest,
139 ArrayRef<OpFoldResult> offsets,
140 ArrayRef<OpFoldResult> sizes,
141 bool tileDestOperands) const {
142 auto linalgOp = cast<LinalgOp>(op);
143
144 // Check that the indexing map used for the output is a projected
145 // permutation. This could be relaxed with a more general approach that can
146 // map the offsets and sizes from the result to iteration space tiles
147 // (filling in full extent for dimensions not used to access the result).
148 AffineMap indexingMap =
149 linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber));
150 if (!indexingMap.isProjectedPermutation()) {
151 return op->emitOpError(
152 "unhandled tiled implementation generation when result is not "
153 "accessed using a permuted projection");
154 }
155
156 auto numLoops = linalgOp.getNumLoops();
157 auto tilingInterfaceOp = cast<TilingInterface>(op);
158 SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
159 iterationTileSizes(numLoops);
160 if (!indexingMap.isPermutation()) {
161 SmallVector<Range> iterationDomain =
162 tilingInterfaceOp.getIterationDomain(b);
163 for (const auto &range : llvm::enumerate(iterationDomain)) {
164 iterationTileOffsets[range.index()] = range.value().offset;
165 iterationTileSizes[range.index()] = range.value().size;
166 }
167 }
168 for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
169 unsigned dimPosition =
170 resultExpr.value().cast<AffineDimExpr>().getPosition();
171 iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
172 iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
173 }
174
175 SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation(
176 b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands);
177 if (tiledOp.size() != 1)
178 return op->emitOpError("failed to generate tiled implementation");
179
180 return tiledOp[0]->getResult(resultNumber);
181 }
182 };
183
184 } // namespace
185
186 template <typename OpType>
registerOne(MLIRContext * ctx)187 static void registerOne(MLIRContext *ctx) {
188 OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
189 }
190
191 /// Variadic helper function.
192 template <typename... OpTypes>
registerAll(MLIRContext * ctx)193 static void registerAll(MLIRContext *ctx) {
194 // FIXME: In c++17 this can be simplified by using 'fold expressions'.
195 (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
196 }
197
198 #define GET_OP_LIST
199
registerTilingInterfaceExternalModels(DialectRegistry & registry)200 void mlir::linalg::registerTilingInterfaceExternalModels(
201 DialectRegistry ®istry) {
202 registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
203 registerOne<linalg::GenericOp>(ctx);
204 registerAll<
205 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
206 >(ctx);
207 });
208 }
209