//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
//
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

namespace mlir {
#include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
} // namespace mlir

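// A ShapeAdaptor wraps one of a ShapedType, a DenseIntElementsAttr holding a
// shape, or a ShapedTypeComponents pointer; the accessors below dispatch on
// whichever form is currently held.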
bool ShapeAdaptor::hasRank() const {
  if (val.isNull())
    return false;
  if (auto t = val.dyn_cast<Type>())
    return t.cast<ShapedType>().hasRank();
  if (val.is<Attribute>())
    return true;
  return val.get<ShapedTypeComponents *>()->hasRank();
}

Type ShapeAdaptor::getElementType() const {
  if (val.isNull())
    return nullptr;
  if (auto t = val.dyn_cast<Type>())
    return t.cast<ShapedType>().getElementType();
  if (val.is<Attribute>())
    return nullptr;
  return val.get<ShapedTypeComponents *>()->getElementType();
}

void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
  assert(hasRank());
  if (auto t = val.dyn_cast<Type>()) {
    ArrayRef<int64_t> vals = t.cast<ShapedType>().getShape();
    res.assign(vals.begin(), vals.end());
  } else if (auto attr = val.dyn_cast<Attribute>()) {
    auto dattr = attr.cast<DenseIntElementsAttr>();
    res.clear();
    res.reserve(dattr.size());
    for (auto it : dattr.getValues<APInt>())
      res.push_back(it.getSExtValue());
  } else {
    auto vals = val.get<ShapedTypeComponents *>()->getDims();
    res.assign(vals.begin(), vals.end());
  }
}

void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
  assert(hasRank());
  res.ranked = true;
  getDims(res.dims);
}

int64_t ShapeAdaptor::getDimSize(int index) const {
  assert(hasRank());
  if (auto t = val.dyn_cast<Type>())
    return t.cast<ShapedType>().getDimSize(index);
  if (auto attr = val.dyn_cast<Attribute>())
    return attr.cast<DenseIntElementsAttr>()
        .getValues<APInt>()[index]
        .getSExtValue();
  auto *stc = val.get<ShapedTypeComponents *>();
  return stc->getDims()[index];
}

int64_t ShapeAdaptor::getRank() const {
  assert(hasRank());
  if (auto t = val.dyn_cast<Type>())
    return t.cast<ShapedType>().getRank();
  if (auto attr = val.dyn_cast<Attribute>())
    return attr.cast<DenseIntElementsAttr>().size();
  return val.get<ShapedTypeComponents *>()->getDims().size();
}

bool ShapeAdaptor::hasStaticShape() const {
  if (!hasRank())
    return false;

  if (auto t = val.dyn_cast<Type>())
    return t.cast<ShapedType>().hasStaticShape();
  if (auto attr = val.dyn_cast<Attribute>()) {
    auto dattr = attr.cast<DenseIntElementsAttr>();
    for (auto index : dattr.getValues<APInt>())
      if (ShapedType::isDynamic(index.getSExtValue()))
        return false;
    return true;
  }
  auto *stc = val.get<ShapedTypeComponents *>();
  for (int64_t dim : stc->getDims())
    if (ShapedType::isDynamic(dim))
      return false;
  return true;
}

int64_t ShapeAdaptor::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");

  if (auto t = val.dyn_cast<Type>())
    return t.cast<ShapedType>().getNumElements();

  if (auto attr = val.dyn_cast<Attribute>()) {
    auto dattr = attr.cast<DenseIntElementsAttr>();
    int64_t num = 1;
    for (auto index : dattr.getValues<APInt>()) {
      num *= index.getZExtValue();
      assert(num >= 0 && "integer overflow in element count computation");
    }
    return num;
  }

  auto *stc = val.get<ShapedTypeComponents *>();
  int64_t num = 1;
  for (int64_t dim : stc->getDims()) {
    num *= dim;
    assert(num >= 0 && "integer overflow in element count computation");
  }
  return num;
}

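// Prints the held shape in a compact form, e.g. `rank = 3 dims = [2x?x4]`,
// with dynamic dimensions rendered as `?`.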
void ShapeAdaptor::dump() const {
  if (!hasRank()) {
    llvm::errs() << "<<unranked>>\n";
    return;
  }

  SmallVector<int64_t> dims;
  getDims(dims);
  auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
    if (ShapedType::isDynamic(dim))
      return "?";
    return llvm::formatv("{0}", dim).str();
  });
  llvm::errs() << "rank = " << getRank() << " dims = [";
  llvm::interleave(mapped, llvm::errs(), "x");
  llvm::errs() << "]\n";
}

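// Returns the shape that the value at `index` evaluates to: the user-provided
// valueToShape callback is consulted first, falling back to matching a
// constant 1-D DenseIntElementsAttr. Returns a null ShapeAdaptor if neither
// yields a shape.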
ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
  Value val = operator[](index);
  if (valueToShape)
    if (ShapeAdaptor ret = valueToShape(val))
      return ret;

  DenseIntElementsAttr attr;
  if (!matchPattern(val, m_Constant(&attr)))
    return nullptr;
  if (attr.getType().getRank() != 1)
    return nullptr;
  return attr;
}

ShapeAdaptor ValueShapeRange::getShape(Value val) const {
  if (operandShape)
    if (ShapeAdaptor ret = operandShape(val))
      return ret;
  return val.getType();
}

ShapeAdaptor ValueShapeRange::getShape(int index) const {
  if (index < 0 || static_cast<size_t>(index) >= size())
    return nullptr;
  return getShape(operator[](index));
}

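// Runs the component-wise inference callback and materializes each resulting
// ShapedTypeComponents as a RankedTensorType or UnrankedTensorType.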
LogicalResult mlir::detail::inferReturnTensorTypes(
    function_ref<LogicalResult(
        MLIRContext *, Optional<Location> location, ValueShapeRange operands,
        DictionaryAttr attributes, RegionRange regions,
        SmallVectorImpl<ShapedTypeComponents> &retComponents)>
        componentTypeFn,
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  SmallVector<ShapedTypeComponents, 2> retComponents;
  if (failed(componentTypeFn(context, location, operands, attributes, regions,
                             retComponents)))
    return failure();
  for (const auto &shapeAndType : retComponents) {
    assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
    if (shapeAndType.hasRank())
      inferredReturnTypes.push_back(RankedTensorType::get(
          shapeAndType.getDims(), shapeAndType.getElementType()));
    else
      inferredReturnTypes.push_back(
          UnrankedTensorType::get(shapeAndType.getElementType()));
  }
  return success();
}

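// Verifies an op implementing InferTypeOpInterface by re-running type
// inference over its operands, attributes, and regions, and checking that the
// inferred types are compatible with the op's actual result types.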
LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
  SmallVector<Type, 4> inferredReturnTypes;
  auto retTypeFn = cast<InferTypeOpInterface>(op);
  if (failed(retTypeFn.inferReturnTypes(
          op->getContext(), op->getLoc(), op->getOperands(),
          op->getAttrDictionary(), op->getRegions(), inferredReturnTypes)))
    return failure();
  if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes,
                                         op->getResultTypes()))
    return op->emitOpError("inferred type(s) ")
           << inferredReturnTypes
           << " are incompatible with return type(s) of operation "
           << op->getResultTypes();
  return success();
}