1 //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Utils/Utils.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Interfaces/TilingInterface.h"
18 
19 using namespace mlir;
20 using namespace mlir::linalg;
21 
22 namespace {
23 
24 /// External model implementation of TilingInterface for LinalgOps. An external
25 /// model implementation is used for now till the use of `TilingInterface` is
26 /// on-par with the current Linalg tiling + fusion patterns. Once it is
27 /// maybe possible to move this into the op-definition (though there are
28 /// advantages to leaving it as an external model)
29 template <typename LinalgOpTy>
30 struct LinalgOpTilingInterface
31     : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
32                                             LinalgOpTy> {
33 
34   /// Return the destination operands.
35   SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
36     return llvm::cast<LinalgOp>(op).getOutputOperands();
37   }
38 
39   /// Return the loop iterator type.
40   SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
41     LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
42     return llvm::to_vector(
43         llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
44           return strAttr.cast<StringAttr>().getValue();
45         }));
46   }
47 
48   /// Return the iteration domain range.
49   SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
50     Location loc = op->getLoc();
51     LinalgOp linalgOp = cast<LinalgOp>(op);
52     auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
53     AffineMap map = linalgOp.getShapesToLoopsMap();
54     Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
55     Value one = b.create<arith::ConstantIndexOp>(loc, 1);
56     return llvm::to_vector(llvm::map_range(
57         applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) {
58           return Range{zero, v, one};
59         }));
60   }
61 
62   // Instantiate the tiled implementation of the operation.
63   SmallVector<Operation *>
64   getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
65                          ArrayRef<OpFoldResult> offsets,
66                          ArrayRef<OpFoldResult> sizes,
67                          bool tileDestOperands) const {
68     // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
69     // specified could lead to out of bounds accesses.
70     Location loc = op->getLoc();
71     LinalgOp linalgOp = cast<LinalgOp>(op);
72     SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
73     SmallVector<Value, 4> tiledOperands = makeTiledShapes(
74         b, loc, linalgOp, valuesToTile,
75         getValueOrCreateConstantIndexOp(b, loc, offsets),
76         getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);
77 
78     SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
79         linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
80           return tiledOperands[opOperand->getOperandNumber()].getType();
81         }));
82 
83     Operation *tiledOp =
84         linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
85 
86     return {tiledOp};
87   }
88 
89   // Return the details of the output tile generated by the tiled
90   // implementation.
91   LogicalResult
92   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
93                         ArrayRef<OpFoldResult> offsets,
94                         ArrayRef<OpFoldResult> sizes,
95                         SmallVector<OpFoldResult> &resultOffsets,
96                         SmallVector<OpFoldResult> &resultSizes) const {
97     Location loc = op->getLoc();
98     LinalgOp linalgOp = cast<LinalgOp>(op);
99 
100     AffineExpr d0;
101     bindDims(b.getContext(), d0);
102 
103     auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc,
104                                                AffineExpr expr,
105                                                ValueRange operands) -> Value {
106       AffineMap map = AffineMap::inferFromExprList({expr}).front();
107       SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
108       mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
109       canonicalizeMapAndOperands(&map, &normalizedOperands);
110       return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
111     };
112 
113     SmallVector<Value> sizeVals =
114         getValueOrCreateConstantIndexOp(b, loc, sizes);
115     SmallVector<Value> subShapeSizes =
116         llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) {
117           return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v);
118         }));
119     OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
120     Value sliceOpResult =
121         makeTiledShape(b, loc, outOperand->get(), sizeVals,
122                        linalgOp.getTiedIndexingMap(outOperand),
123                        getValueOrCreateConstantIndexOp(b, loc, offsets),
124                        /*ubs*/ {}, subShapeSizes, true);
125     auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
126     if (!sliceOp)
127       return failure();
128     resultOffsets = sliceOp.getMixedOffsets();
129     resultSizes = sliceOp.getMixedSizes();
130     return success();
131   }
132 };
133 
134 } // namespace
135 
136 template <typename OpType> static void registerOne(MLIRContext *ctx) {
137   OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
138 }
139 
140 /// Variadic helper function.
141 template <typename... OpTypes> static void registerAll(MLIRContext *ctx) {
142   // FIXME: In c++17 this can be simplified by using 'fold expressions'.
143   (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
144 }
145 
146 #define GET_OP_LIST
147 
148 void mlir::linalg::registerTilingInterfaceExternalModels(
149     DialectRegistry &registry) {
150   registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
151     registerOne<linalg::GenericOp>(ctx);
152     registerAll<
153 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
154         >(ctx);
155   });
156 }
157