1 //===- InferTypeOpImpl.cpp - InferType Interface external models *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 
15 using namespace mlir;
16 using namespace mlir::tensor;
17 
18 /// Compute a map that for a given dimension of the expanded type gives the
19 /// dimension in the collapsed type it maps to. Essentially its the inverse of
20 /// the `reassocation` maps.
21 static llvm::DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation)22 getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
23   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
24   for (const auto &map : enumerate(reassociation)) {
25     unsigned startPos =
26         map.value().getResults().front().cast<AffineDimExpr>().getPosition();
27     unsigned endPos =
28         map.value().getResults().back().cast<AffineDimExpr>().getPosition();
29     for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
30       expandedDimToCollapsedDim[dim] = map.index();
31     }
32   }
33   return expandedDimToCollapsedDim;
34 }
35 
36 /// For reshape op compute the shape at dimension `dimIndex` of the output in
37 /// terms of shape of the `src`, when the reshape op is a collapsing
38 /// operation. It is the product of the shape of the collapsed dimensions of the
39 /// `src`.
40 static OpFoldResult
getCollapsedOutputDimFromInputShape(OpBuilder & builder,Location loc,int64_t dimIndex,Value src,ArrayRef<AffineMap> reassociationMap)41 getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
42                                     int64_t dimIndex, Value src,
43                                     ArrayRef<AffineMap> reassociationMap) {
44   AffineMap map = reassociationMap[dimIndex];
45   unsigned startPos =
46       map.getResults().front().cast<AffineDimExpr>().getPosition();
47   unsigned endPos = map.getResults().back().cast<AffineDimExpr>().getPosition();
48   AffineExpr expr;
49   SmallVector<Value, 2> dynamicDims;
50   for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
51     dynamicDims.push_back(builder.createOrFold<tensor::DimOp>(loc, src, dim));
52     AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
53     expr = (expr ? expr * currExpr : currExpr);
54   }
55   return applyMapToValues(builder, loc,
56                           AffineMap::get(0, endPos - startPos + 1, expr),
57                           dynamicDims)[0];
58 }
59 
60 /// Given the `src` of a collapsing reshape op and its reassociation maps,
61 /// compute the shape of the result of the reshape.
getCollapsedOutputShapeFromInputShape(OpBuilder & builder,Location loc,Value src,ArrayRef<int64_t> dstStaticShape,ArrayRef<AffineMap> reassociation)62 static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
63     OpBuilder &builder, Location loc, Value src,
64     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
65   return llvm::to_vector<4>(llvm::map_range(
66       llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
67         return getCollapsedOutputDimFromInputShape(builder, loc, dim, src,
68                                                    reassociation);
69       }));
70 }
71 
72 /// For an expanding reshape op, compute the value for a dimension of the output
73 /// from the shape of the input.
getExpandedOutputDimFromInputShape(OpBuilder & builder,Location loc,int64_t dimIndex,Value src,ArrayRef<int64_t> dstStaticShape,ArrayRef<AffineMap> reassociation,llvm::DenseMap<int64_t,int64_t> & expandedDimToCollapsedDim)74 static OpFoldResult getExpandedOutputDimFromInputShape(
75     OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
76     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
77     llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
78   if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
79     return builder.getI64IntegerAttr(dstStaticShape[dimIndex]);
80   }
81   unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
82   unsigned startPos = reassociation[sourceDimPos]
83                           .getResults()
84                           .front()
85                           .cast<AffineDimExpr>()
86                           .getPosition();
87   unsigned endPos = reassociation[sourceDimPos]
88                         .getResults()
89                         .back()
90                         .cast<AffineDimExpr>()
91                         .getPosition();
92   int64_t linearizedStaticDim = 1;
93   for (auto &d :
94        llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
95     if (d.index() + startPos == static_cast<unsigned>(dimIndex))
96       continue;
97     assert(!ShapedType::isDynamic(d.value()) &&
98            "single dimension cannot be expanded into multiple dynamic "
99            "dimensions");
100     linearizedStaticDim *= d.value();
101   }
102   Value sourceDim = builder.create<tensor::DimOp>(loc, src, sourceDimPos);
103   return applyMapToValues(
104       builder, loc,
105       AffineMap::get(
106           0, 1, builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
107       sourceDim)[0];
108 }
109 
110 /// Given the `src` of an expanding reshape op, the reassociation maps and the
111 /// result type, compute the shape of the result of the reshape.
getExpandedOutputShapeFromInputShape(OpBuilder & builder,Location loc,Value src,ArrayRef<int64_t> dstStaticShape,ArrayRef<AffineMap> reassociation)112 static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
113     OpBuilder &builder, Location loc, Value src,
114     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
115   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
116       getExpandedDimToCollapsedDimMap(reassociation);
117   return llvm::to_vector<4>(llvm::map_range(
118       llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
119         return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
120                                                   dstStaticShape, reassociation,
121                                                   expandedDimToCollapsedDim);
122       }));
123 }
124 
125 static SmallVector<OpFoldResult, 4>
getReshapeOutputShapeFromInputShape(OpBuilder & builder,Location loc,Value src,ArrayRef<int64_t> dstStaticShape,ArrayRef<AffineMap> reassocation)126 getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
127                                     ArrayRef<int64_t> dstStaticShape,
128                                     ArrayRef<AffineMap> reassocation) {
129   return dstStaticShape.size() >
130                  static_cast<size_t>(src.getType().cast<ShapedType>().getRank())
131              ? getExpandedOutputShapeFromInputShape(
132                    builder, loc, src, dstStaticShape, reassocation)
133              : getCollapsedOutputShapeFromInputShape(
134                    builder, loc, src, dstStaticShape, reassocation);
135 }
136 
137 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
138 /// `Value`s.
getAsValues(OpBuilder & b,Location loc,ArrayRef<OpFoldResult> valueOrAttrVec)139 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
140                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
141   return llvm::to_vector<4>(
142       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
143         return getValueOrCreateConstantIndexOp(b, loc, value);
144       }));
145 }
146 
147 template <typename OpTy>
148 struct ReifyExpandOrCollapseShapeOp
149     : public ReifyRankedShapedTypeOpInterface::ExternalModel<
150           ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
151   LogicalResult
reifyResultShapesReifyExpandOrCollapseShapeOp152   reifyResultShapes(Operation *op, OpBuilder &b,
153                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
154     auto loc = op->getLoc();
155     auto reshapeOp = cast<OpTy>(op);
156     auto resultShape = getReshapeOutputShapeFromInputShape(
157         b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
158         reshapeOp.getReassociationMaps());
159     reifiedReturnShapes.push_back(getAsValues(b, loc, resultShape));
160     return success();
161   }
162 };
163 
164 namespace {
165 
166 struct ReifyPadOp
167     : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
168                                                              PadOp> {
169   LogicalResult
reifyResultShapes__anonf4e57e250411::ReifyPadOp170   reifyResultShapes(Operation *op, OpBuilder &b,
171                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
172     auto padOp = cast<PadOp>(op);
173     Location loc = padOp.getLoc();
174     auto lowPad = padOp.getMixedLowPad();
175     auto highPad = padOp.getMixedHighPad();
176     SmallVector<Value> shapes;
177     for (auto dim : llvm::seq<int64_t>(0, padOp.getSourceType().getRank())) {
178       // Shape along each dimension is source dim + low pad + high pad.
179       SmallVector<Value> mapOperands;
180       mapOperands.push_back(
181           b.createOrFold<tensor::DimOp>(loc, padOp.getSource(), dim));
182       AffineExpr expr = b.getAffineDimExpr(0);
183       unsigned numSymbols = 0;
184       auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
185         if (Value v = valueOrAttr.dyn_cast<Value>()) {
186           expr = expr + b.getAffineSymbolExpr(numSymbols++);
187           mapOperands.push_back(v);
188           return;
189         }
190         int64_t staticValue =
191             valueOrAttr.get<Attribute>().cast<IntegerAttr>().getInt();
192         expr = expr + staticValue;
193       };
194       addOpFoldResult(lowPad[dim]);
195       addOpFoldResult(highPad[dim]);
196       shapes.push_back(applyMapToValues(
197           b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]);
198     }
199     reifiedReturnShapes.emplace_back(std::move(shapes));
200     return success();
201   }
202 };
203 
204 } // namespace
205 
registerInferTypeOpInterfaceExternalModels(DialectRegistry & registry)206 void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
207     DialectRegistry &registry) {
208   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
209     ExpandShapeOp::attachInterface<
210         ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
211     CollapseShapeOp::attachInterface<
212         ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
213     PadOp::attachInterface<ReifyPadOp>(*ctx);
214   });
215 }
216