1 //===- InferTypeOpImpl.cpp - InferType Interface external models *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14
15 using namespace mlir;
16 using namespace mlir::tensor;
17
18 /// Compute a map that for a given dimension of the expanded type gives the
19 /// dimension in the collapsed type it maps to. Essentially its the inverse of
20 /// the `reassocation` maps.
21 static llvm::DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation)22 getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
23 llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
24 for (const auto &map : enumerate(reassociation)) {
25 unsigned startPos =
26 map.value().getResults().front().cast<AffineDimExpr>().getPosition();
27 unsigned endPos =
28 map.value().getResults().back().cast<AffineDimExpr>().getPosition();
29 for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
30 expandedDimToCollapsedDim[dim] = map.index();
31 }
32 }
33 return expandedDimToCollapsedDim;
34 }
35
36 /// For reshape op compute the shape at dimension `dimIndex` of the output in
37 /// terms of shape of the `src`, when the reshape op is a collapsing
38 /// operation. It is the product of the shape of the collapsed dimensions of the
39 /// `src`.
40 static OpFoldResult
getCollapsedOutputDimFromInputShape(OpBuilder & builder,Location loc,int64_t dimIndex,Value src,ArrayRef<AffineMap> reassociationMap)41 getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
42 int64_t dimIndex, Value src,
43 ArrayRef<AffineMap> reassociationMap) {
44 AffineMap map = reassociationMap[dimIndex];
45 unsigned startPos =
46 map.getResults().front().cast<AffineDimExpr>().getPosition();
47 unsigned endPos = map.getResults().back().cast<AffineDimExpr>().getPosition();
48 AffineExpr expr;
49 SmallVector<Value, 2> dynamicDims;
50 for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
51 dynamicDims.push_back(builder.createOrFold<tensor::DimOp>(loc, src, dim));
52 AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
53 expr = (expr ? expr * currExpr : currExpr);
54 }
55 return applyMapToValues(builder, loc,
56 AffineMap::get(0, endPos - startPos + 1, expr),
57 dynamicDims)[0];
58 }
59
60 /// Given the `src` of a collapsing reshape op and its reassociation maps,
61 /// compute the shape of the result of the reshape.
getCollapsedOutputShapeFromInputShape(OpBuilder & builder,Location loc,Value src,ArrayRef<int64_t> dstStaticShape,ArrayRef<AffineMap> reassociation)62 static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
63 OpBuilder &builder, Location loc, Value src,
64 ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
65 return llvm::to_vector<4>(llvm::map_range(
66 llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
67 return getCollapsedOutputDimFromInputShape(builder, loc, dim, src,
68 reassociation);
69 }));
70 }
71
72 /// For an expanding reshape op, compute the value for a dimension of the output
73 /// from the shape of the input.
getExpandedOutputDimFromInputShape(OpBuilder & builder,Location loc,int64_t dimIndex,Value src,ArrayRef<int64_t> dstStaticShape,ArrayRef<AffineMap> reassociation,llvm::DenseMap<int64_t,int64_t> & expandedDimToCollapsedDim)74 static OpFoldResult getExpandedOutputDimFromInputShape(
75 OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
76 ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
77 llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
78 if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
79 return builder.getI64IntegerAttr(dstStaticShape[dimIndex]);
80 }
81 unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
82 unsigned startPos = reassociation[sourceDimPos]
83 .getResults()
84 .front()
85 .cast<AffineDimExpr>()
86 .getPosition();
87 unsigned endPos = reassociation[sourceDimPos]
88 .getResults()
89 .back()
90 .cast<AffineDimExpr>()
91 .getPosition();
92 int64_t linearizedStaticDim = 1;
93 for (auto &d :
94 llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
95 if (d.index() + startPos == static_cast<unsigned>(dimIndex))
96 continue;
97 assert(!ShapedType::isDynamic(d.value()) &&
98 "single dimension cannot be expanded into multiple dynamic "
99 "dimensions");
100 linearizedStaticDim *= d.value();
101 }
102 Value sourceDim = builder.create<tensor::DimOp>(loc, src, sourceDimPos);
103 return applyMapToValues(
104 builder, loc,
105 AffineMap::get(
106 0, 1, builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
107 sourceDim)[0];
108 }
109
110 /// Given the `src` of an expanding reshape op, the reassociation maps and the
111 /// result type, compute the shape of the result of the reshape.
getExpandedOutputShapeFromInputShape(OpBuilder & builder,Location loc,Value src,ArrayRef<int64_t> dstStaticShape,ArrayRef<AffineMap> reassociation)112 static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
113 OpBuilder &builder, Location loc, Value src,
114 ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
115 llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
116 getExpandedDimToCollapsedDimMap(reassociation);
117 return llvm::to_vector<4>(llvm::map_range(
118 llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
119 return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
120 dstStaticShape, reassociation,
121 expandedDimToCollapsedDim);
122 }));
123 }
124
125 static SmallVector<OpFoldResult, 4>
getReshapeOutputShapeFromInputShape(OpBuilder & builder,Location loc,Value src,ArrayRef<int64_t> dstStaticShape,ArrayRef<AffineMap> reassocation)126 getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
127 ArrayRef<int64_t> dstStaticShape,
128 ArrayRef<AffineMap> reassocation) {
129 return dstStaticShape.size() >
130 static_cast<size_t>(src.getType().cast<ShapedType>().getRank())
131 ? getExpandedOutputShapeFromInputShape(
132 builder, loc, src, dstStaticShape, reassocation)
133 : getCollapsedOutputShapeFromInputShape(
134 builder, loc, src, dstStaticShape, reassocation);
135 }
136
137 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
138 /// `Value`s.
getAsValues(OpBuilder & b,Location loc,ArrayRef<OpFoldResult> valueOrAttrVec)139 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
140 ArrayRef<OpFoldResult> valueOrAttrVec) {
141 return llvm::to_vector<4>(
142 llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
143 return getValueOrCreateConstantIndexOp(b, loc, value);
144 }));
145 }
146
147 template <typename OpTy>
148 struct ReifyExpandOrCollapseShapeOp
149 : public ReifyRankedShapedTypeOpInterface::ExternalModel<
150 ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
151 LogicalResult
reifyResultShapesReifyExpandOrCollapseShapeOp152 reifyResultShapes(Operation *op, OpBuilder &b,
153 ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
154 auto loc = op->getLoc();
155 auto reshapeOp = cast<OpTy>(op);
156 auto resultShape = getReshapeOutputShapeFromInputShape(
157 b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
158 reshapeOp.getReassociationMaps());
159 reifiedReturnShapes.push_back(getAsValues(b, loc, resultShape));
160 return success();
161 }
162 };
163
164 namespace {
165
166 struct ReifyPadOp
167 : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
168 PadOp> {
169 LogicalResult
reifyResultShapes__anonf4e57e250411::ReifyPadOp170 reifyResultShapes(Operation *op, OpBuilder &b,
171 ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
172 auto padOp = cast<PadOp>(op);
173 Location loc = padOp.getLoc();
174 auto lowPad = padOp.getMixedLowPad();
175 auto highPad = padOp.getMixedHighPad();
176 SmallVector<Value> shapes;
177 for (auto dim : llvm::seq<int64_t>(0, padOp.getSourceType().getRank())) {
178 // Shape along each dimension is source dim + low pad + high pad.
179 SmallVector<Value> mapOperands;
180 mapOperands.push_back(
181 b.createOrFold<tensor::DimOp>(loc, padOp.getSource(), dim));
182 AffineExpr expr = b.getAffineDimExpr(0);
183 unsigned numSymbols = 0;
184 auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
185 if (Value v = valueOrAttr.dyn_cast<Value>()) {
186 expr = expr + b.getAffineSymbolExpr(numSymbols++);
187 mapOperands.push_back(v);
188 return;
189 }
190 int64_t staticValue =
191 valueOrAttr.get<Attribute>().cast<IntegerAttr>().getInt();
192 expr = expr + staticValue;
193 };
194 addOpFoldResult(lowPad[dim]);
195 addOpFoldResult(highPad[dim]);
196 shapes.push_back(applyMapToValues(
197 b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]);
198 }
199 reifiedReturnShapes.emplace_back(std::move(shapes));
200 return success();
201 }
202 };
203
204 } // namespace
205
registerInferTypeOpInterfaceExternalModels(DialectRegistry & registry)206 void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
207 DialectRegistry ®istry) {
208 registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
209 ExpandShapeOp::attachInterface<
210 ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
211 CollapseShapeOp::attachInterface<
212 ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
213 PadOp::attachInterface<ReifyPadOp>(*ctx);
214 });
215 }
216