1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"

#include <limits>
18
19 using namespace mlir;
20
21 namespace mlir {
22 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
23 } // namespace mlir
24
hasRank() const25 bool ShapeAdaptor::hasRank() const {
26 if (val.isNull())
27 return false;
28 if (auto t = val.dyn_cast<Type>())
29 return t.cast<ShapedType>().hasRank();
30 if (val.is<Attribute>())
31 return true;
32 return val.get<ShapedTypeComponents *>()->hasRank();
33 }
34
getElementType() const35 Type ShapeAdaptor::getElementType() const {
36 if (val.isNull())
37 return nullptr;
38 if (auto t = val.dyn_cast<Type>())
39 return t.cast<ShapedType>().getElementType();
40 if (val.is<Attribute>())
41 return nullptr;
42 return val.get<ShapedTypeComponents *>()->getElementType();
43 }
44
getDims(SmallVectorImpl<int64_t> & res) const45 void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
46 assert(hasRank());
47 if (auto t = val.dyn_cast<Type>()) {
48 ArrayRef<int64_t> vals = t.cast<ShapedType>().getShape();
49 res.assign(vals.begin(), vals.end());
50 } else if (auto attr = val.dyn_cast<Attribute>()) {
51 auto dattr = attr.cast<DenseIntElementsAttr>();
52 res.clear();
53 res.reserve(dattr.size());
54 for (auto it : dattr.getValues<APInt>())
55 res.push_back(it.getSExtValue());
56 } else {
57 auto vals = val.get<ShapedTypeComponents *>()->getDims();
58 res.assign(vals.begin(), vals.end());
59 }
60 }
61
getDims(ShapedTypeComponents & res) const62 void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
63 assert(hasRank());
64 res.ranked = true;
65 getDims(res.dims);
66 }
67
getDimSize(int index) const68 int64_t ShapeAdaptor::getDimSize(int index) const {
69 assert(hasRank());
70 if (auto t = val.dyn_cast<Type>())
71 return t.cast<ShapedType>().getDimSize(index);
72 if (auto attr = val.dyn_cast<Attribute>())
73 return attr.cast<DenseIntElementsAttr>()
74 .getValues<APInt>()[index]
75 .getSExtValue();
76 auto *stc = val.get<ShapedTypeComponents *>();
77 return stc->getDims()[index];
78 }
79
getRank() const80 int64_t ShapeAdaptor::getRank() const {
81 assert(hasRank());
82 if (auto t = val.dyn_cast<Type>())
83 return t.cast<ShapedType>().getRank();
84 if (auto attr = val.dyn_cast<Attribute>())
85 return attr.cast<DenseIntElementsAttr>().size();
86 return val.get<ShapedTypeComponents *>()->getDims().size();
87 }
88
hasStaticShape() const89 bool ShapeAdaptor::hasStaticShape() const {
90 if (!hasRank())
91 return false;
92
93 if (auto t = val.dyn_cast<Type>())
94 return t.cast<ShapedType>().hasStaticShape();
95 if (auto attr = val.dyn_cast<Attribute>()) {
96 auto dattr = attr.cast<DenseIntElementsAttr>();
97 for (auto index : dattr.getValues<APInt>())
98 if (ShapedType::isDynamic(index.getSExtValue()))
99 return false;
100 return true;
101 }
102 auto *stc = val.get<ShapedTypeComponents *>();
103 return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
104 }
105
getNumElements() const106 int64_t ShapeAdaptor::getNumElements() const {
107 assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
108
109 if (auto t = val.dyn_cast<Type>())
110 return t.cast<ShapedType>().getNumElements();
111
112 if (auto attr = val.dyn_cast<Attribute>()) {
113 auto dattr = attr.cast<DenseIntElementsAttr>();
114 int64_t num = 1;
115 for (auto index : dattr.getValues<APInt>()) {
116 num *= index.getZExtValue();
117 assert(num >= 0 && "integer overflow in element count computation");
118 }
119 return num;
120 }
121
122 auto *stc = val.get<ShapedTypeComponents *>();
123 int64_t num = 1;
124 for (int64_t dim : stc->getDims()) {
125 num *= dim;
126 assert(num >= 0 && "integer overflow in element count computation");
127 }
128 return num;
129 }
130
dump() const131 void ShapeAdaptor::dump() const {
132 if (!hasRank()) {
133 llvm::errs() << "<<unranked>>\n";
134 return;
135 }
136
137 SmallVector<int64_t> dims;
138 getDims(dims);
139 auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
140 if (ShapedType::isDynamic(dim))
141 return "?";
142 return llvm::formatv("{0}", dim).str();
143 });
144 llvm::errs() << "rank = " << getRank() << " dims = [";
145 llvm::interleave(mapped, llvm::errs(), "x");
146 llvm::errs() << "]\n";
147 }
148
getValueAsShape(int index)149 ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
150 Value val = operator[](index);
151 if (valueToShape)
152 if (ShapeAdaptor ret = valueToShape(val))
153 return ret;
154
155 DenseIntElementsAttr attr;
156 if (!matchPattern(val, m_Constant(&attr)))
157 return nullptr;
158 if (attr.getType().getRank() != 1)
159 return nullptr;
160 return attr;
161 }
162
getShape(Value val) const163 ShapeAdaptor ValueShapeRange::getShape(Value val) const {
164 if (operandShape)
165 if (ShapeAdaptor ret = operandShape(val))
166 return ret;
167 return val.getType();
168 }
169
getShape(int index) const170 ShapeAdaptor ValueShapeRange::getShape(int index) const {
171 if (index < 0 || static_cast<size_t>(index) >= size())
172 return nullptr;
173 return getShape(operator[](index));
174 }
175
inferReturnTensorTypes(function_ref<LogicalResult (MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & retComponents)> componentTypeFn,MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)176 LogicalResult mlir::detail::inferReturnTensorTypes(
177 function_ref<LogicalResult(
178 MLIRContext *, Optional<Location> location, ValueShapeRange operands,
179 DictionaryAttr attributes, RegionRange regions,
180 SmallVectorImpl<ShapedTypeComponents> &retComponents)>
181 componentTypeFn,
182 MLIRContext *context, Optional<Location> location, ValueRange operands,
183 DictionaryAttr attributes, RegionRange regions,
184 SmallVectorImpl<Type> &inferredReturnTypes) {
185 SmallVector<ShapedTypeComponents, 2> retComponents;
186 if (failed(componentTypeFn(context, location, operands, attributes, regions,
187 retComponents)))
188 return failure();
189 for (const auto &shapeAndType : retComponents) {
190 assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
191 assert(shapeAndType.getElementType() &&
192 "element type required to construct tensor");
193 if (shapeAndType.hasRank())
194 inferredReturnTypes.push_back(RankedTensorType::get(
195 shapeAndType.getDims(), shapeAndType.getElementType()));
196 else
197 inferredReturnTypes.push_back(
198 UnrankedTensorType::get(shapeAndType.getElementType()));
199 }
200 return success();
201 }
202
verifyInferredResultTypes(Operation * op)203 LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
204 SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
205 auto retTypeFn = cast<InferTypeOpInterface>(op);
206 return retTypeFn.refineReturnTypes(op->getContext(), op->getLoc(),
207 op->getOperands(), op->getAttrDictionary(),
208 op->getRegions(), inferredReturnTypes);
209 }
210