1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"
#include <limits>
18 
19 using namespace mlir;
20 
21 namespace mlir {
22 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
23 } // namespace mlir
24 
hasRank() const25 bool ShapeAdaptor::hasRank() const {
26   if (val.isNull())
27     return false;
28   if (auto t = val.dyn_cast<Type>())
29     return t.cast<ShapedType>().hasRank();
30   if (val.is<Attribute>())
31     return true;
32   return val.get<ShapedTypeComponents *>()->hasRank();
33 }
34 
getElementType() const35 Type ShapeAdaptor::getElementType() const {
36   if (val.isNull())
37     return nullptr;
38   if (auto t = val.dyn_cast<Type>())
39     return t.cast<ShapedType>().getElementType();
40   if (val.is<Attribute>())
41     return nullptr;
42   return val.get<ShapedTypeComponents *>()->getElementType();
43 }
44 
getDims(SmallVectorImpl<int64_t> & res) const45 void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
46   assert(hasRank());
47   if (auto t = val.dyn_cast<Type>()) {
48     ArrayRef<int64_t> vals = t.cast<ShapedType>().getShape();
49     res.assign(vals.begin(), vals.end());
50   } else if (auto attr = val.dyn_cast<Attribute>()) {
51     auto dattr = attr.cast<DenseIntElementsAttr>();
52     res.clear();
53     res.reserve(dattr.size());
54     for (auto it : dattr.getValues<APInt>())
55       res.push_back(it.getSExtValue());
56   } else {
57     auto vals = val.get<ShapedTypeComponents *>()->getDims();
58     res.assign(vals.begin(), vals.end());
59   }
60 }
61 
getDims(ShapedTypeComponents & res) const62 void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
63   assert(hasRank());
64   res.ranked = true;
65   getDims(res.dims);
66 }
67 
getDimSize(int index) const68 int64_t ShapeAdaptor::getDimSize(int index) const {
69   assert(hasRank());
70   if (auto t = val.dyn_cast<Type>())
71     return t.cast<ShapedType>().getDimSize(index);
72   if (auto attr = val.dyn_cast<Attribute>())
73     return attr.cast<DenseIntElementsAttr>()
74         .getValues<APInt>()[index]
75         .getSExtValue();
76   auto *stc = val.get<ShapedTypeComponents *>();
77   return stc->getDims()[index];
78 }
79 
getRank() const80 int64_t ShapeAdaptor::getRank() const {
81   assert(hasRank());
82   if (auto t = val.dyn_cast<Type>())
83     return t.cast<ShapedType>().getRank();
84   if (auto attr = val.dyn_cast<Attribute>())
85     return attr.cast<DenseIntElementsAttr>().size();
86   return val.get<ShapedTypeComponents *>()->getDims().size();
87 }
88 
hasStaticShape() const89 bool ShapeAdaptor::hasStaticShape() const {
90   if (!hasRank())
91     return false;
92 
93   if (auto t = val.dyn_cast<Type>())
94     return t.cast<ShapedType>().hasStaticShape();
95   if (auto attr = val.dyn_cast<Attribute>()) {
96     auto dattr = attr.cast<DenseIntElementsAttr>();
97     for (auto index : dattr.getValues<APInt>())
98       if (ShapedType::isDynamic(index.getSExtValue()))
99         return false;
100     return true;
101   }
102   auto *stc = val.get<ShapedTypeComponents *>();
103   return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
104 }
105 
getNumElements() const106 int64_t ShapeAdaptor::getNumElements() const {
107   assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
108 
109   if (auto t = val.dyn_cast<Type>())
110     return t.cast<ShapedType>().getNumElements();
111 
112   if (auto attr = val.dyn_cast<Attribute>()) {
113     auto dattr = attr.cast<DenseIntElementsAttr>();
114     int64_t num = 1;
115     for (auto index : dattr.getValues<APInt>()) {
116       num *= index.getZExtValue();
117       assert(num >= 0 && "integer overflow in element count computation");
118     }
119     return num;
120   }
121 
122   auto *stc = val.get<ShapedTypeComponents *>();
123   int64_t num = 1;
124   for (int64_t dim : stc->getDims()) {
125     num *= dim;
126     assert(num >= 0 && "integer overflow in element count computation");
127   }
128   return num;
129 }
130 
dump() const131 void ShapeAdaptor::dump() const {
132   if (!hasRank()) {
133     llvm::errs() << "<<unranked>>\n";
134     return;
135   }
136 
137   SmallVector<int64_t> dims;
138   getDims(dims);
139   auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
140     if (ShapedType::isDynamic(dim))
141       return "?";
142     return llvm::formatv("{0}", dim).str();
143   });
144   llvm::errs() << "rank = " << getRank() << " dims = [";
145   llvm::interleave(mapped, llvm::errs(), "x");
146   llvm::errs() << "]\n";
147 }
148 
getValueAsShape(int index)149 ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
150   Value val = operator[](index);
151   if (valueToShape)
152     if (ShapeAdaptor ret = valueToShape(val))
153       return ret;
154 
155   DenseIntElementsAttr attr;
156   if (!matchPattern(val, m_Constant(&attr)))
157     return nullptr;
158   if (attr.getType().getRank() != 1)
159     return nullptr;
160   return attr;
161 }
162 
getShape(Value val) const163 ShapeAdaptor ValueShapeRange::getShape(Value val) const {
164   if (operandShape)
165     if (ShapeAdaptor ret = operandShape(val))
166       return ret;
167   return val.getType();
168 }
169 
getShape(int index) const170 ShapeAdaptor ValueShapeRange::getShape(int index) const {
171   if (index < 0 || static_cast<size_t>(index) >= size())
172     return nullptr;
173   return getShape(operator[](index));
174 }
175 
inferReturnTensorTypes(function_ref<LogicalResult (MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & retComponents)> componentTypeFn,MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)176 LogicalResult mlir::detail::inferReturnTensorTypes(
177     function_ref<LogicalResult(
178         MLIRContext *, Optional<Location> location, ValueShapeRange operands,
179         DictionaryAttr attributes, RegionRange regions,
180         SmallVectorImpl<ShapedTypeComponents> &retComponents)>
181         componentTypeFn,
182     MLIRContext *context, Optional<Location> location, ValueRange operands,
183     DictionaryAttr attributes, RegionRange regions,
184     SmallVectorImpl<Type> &inferredReturnTypes) {
185   SmallVector<ShapedTypeComponents, 2> retComponents;
186   if (failed(componentTypeFn(context, location, operands, attributes, regions,
187                              retComponents)))
188     return failure();
189   for (const auto &shapeAndType : retComponents) {
190     assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
191     assert(shapeAndType.getElementType() &&
192            "element type required to construct tensor");
193     if (shapeAndType.hasRank())
194       inferredReturnTypes.push_back(RankedTensorType::get(
195           shapeAndType.getDims(), shapeAndType.getElementType()));
196     else
197       inferredReturnTypes.push_back(
198           UnrankedTensorType::get(shapeAndType.getElementType()));
199   }
200   return success();
201 }
202 
verifyInferredResultTypes(Operation * op)203 LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
204   SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
205   auto retTypeFn = cast<InferTypeOpInterface>(op);
206   return retTypeFn.refineReturnTypes(op->getContext(), op->getLoc(),
207                                      op->getOperands(), op->getAttrDictionary(),
208                                      op->getRegions(), inferredReturnTypes);
209 }
210