1 //===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
15 #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
16 
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/PointerUnion.h"
24 #include "llvm/ADT/SmallVector.h"
25 
26 namespace mlir {
27 
28 class ShapedTypeComponents;
29 using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<Value>>;
30 
31 /// Adaptor class to abstract the differences between whether value is from
32 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
33 class ShapeAdaptor {
34 public:
ShapeAdaptor(Type t)35   ShapeAdaptor(Type t) {
36     if (auto st = t.dyn_cast<ShapedType>())
37       val = st;
38   }
ShapeAdaptor(Attribute t)39   ShapeAdaptor(Attribute t) {
40     if (auto da = t.dyn_cast<DenseIntElementsAttr>())
41       val = da;
42   }
ShapeAdaptor(ShapedTypeComponents * components)43   ShapeAdaptor(ShapedTypeComponents *components) : val(components) {}
ShapeAdaptor(ShapedTypeComponents & components)44   ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {}
45 
46   /// Returns whether the shape has a rank.
47   bool hasRank() const;
48 
49   /// Returns the element type.
50   Type getElementType() const;
51 
52   /// Populates the dimensions from shape referenced.
53   /// Requires: shape is ranked.
54   void getDims(SmallVectorImpl<int64_t> &res) const;
55 
56   /// Populates the dimensions of the ShapeTypeComponents.
57   /// Requires: shape is ranked.
58   void getDims(ShapedTypeComponents &res) const;
59 
60   /// Returns the size of the index'th dimension.
61   /// Requires: shape is ranked.
62   int64_t getDimSize(int index) const;
63 
64   /// Returns whether the index'th dimension is dynamic.
65   /// Requires: shape is ranked.
isDynamicDim(int index)66   bool isDynamicDim(int index) const {
67     return ShapedType::isDynamic(getDimSize(index));
68   }
69 
70   /// Returns whether the shape is fully static.
71   bool hasStaticShape() const;
72 
73   /// Returns the rank of the shape.
74   /// Requires: shape is ranked.
75   int64_t getRank() const;
76 
77   /// Returns the number of elements in the shape.
78   /// Requires: hasStaticShape
79   int64_t getNumElements() const;
80 
81   /// Returns whether valid (non-null) shape.
82   explicit operator bool() const { return !val.isNull(); }
83 
84   /// Dumps textual repesentation to stderr.
85   void dump() const;
86 
87 private:
88   // Union storing either ShapedTypeComponents, ShapedType (stored as Type and
89   // casted), or DenseIntElementsAttribute (stored as Atrtribute).
90   PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr;
91 };
92 
93 /// ShapedTypeComponents that represents the components of a ShapedType.
94 /// The components consist of
95 ///  - A ranked or unranked shape with the dimension specification match those
96 ///    of ShapeType's getShape() (e.g., dynamic dimension represented using
97 ///    ShapedType::kDynamicSize)
98 ///  - A element type, may be unset (nullptr)
99 ///  - A attribute, may be unset (nullptr)
100 /// Used by ShapedType type inferences.
101 class ShapedTypeComponents {
102   /// Internal storage type for shape.
103   using ShapeStorageT = SmallVector<int64_t, 3>;
104 
105 public:
106   /// Default construction is an unranked shape.
ShapedTypeComponents()107   ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
ShapedTypeComponents(Type elementType)108   ShapedTypeComponents(Type elementType)
109       : elementType(elementType), attr(nullptr), ranked(false) {}
ShapedTypeComponents(ShapedType shapedType)110   ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
111     ranked = shapedType.hasRank();
112     elementType = shapedType.getElementType();
113     if (ranked)
114       dims = llvm::to_vector<4>(shapedType.getShape());
115   }
ShapedTypeComponents(ShapeAdaptor adaptor)116   ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) {
117     ranked = adaptor.hasRank();
118     elementType = adaptor.getElementType();
119     if (ranked)
120       adaptor.getDims(*this);
121   }
122   template <typename Arg, typename = typename std::enable_if_t<
123                               std::is_constructible<ShapeStorageT, Arg>::value>>
124   ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
125                        Attribute attr = nullptr)
dims(std::forward<Arg> (arg))126       : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
127         ranked(true) {}
128   ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
129                        Attribute attr = nullptr)
130       : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
131         ranked(true) {}
132 
133   /// Return the dimensions of the shape.
134   /// Requires: shape is ranked.
getDims()135   ArrayRef<int64_t> getDims() const {
136     assert(ranked && "requires ranked shape");
137     return dims;
138   }
139 
140   /// Return whether the shape has a rank.
hasRank()141   bool hasRank() const { return ranked; };
142 
143   /// Return the element type component.
getElementType()144   Type getElementType() const { return elementType; };
145 
146   /// Return the raw attribute component.
getAttribute()147   Attribute getAttribute() const { return attr; };
148 
149 private:
150   friend class ShapeAdaptor;
151 
152   ShapeStorageT dims;
153   Type elementType;
154   Attribute attr;
155   bool ranked{false};
156 };
157 
158 /// Range of values and shapes (corresponding effectively to Shapes dialect's
159 /// ValueShape type concept).
160 // Currently this exposes the Value (of operands) and Type of the Value. This is
161 // not ideal as then one can accidentally reference an out of date shape. This
162 // is done to both enable gradual switch and also as OpAdaptor doesn't currently
163 // allow returning anything other than Value.
164 class ValueShapeRange : public ValueRange::RangeBaseT {
165 public:
166   using ValueShapeMapFn = function_ref<ShapeAdaptor(Value)>;
167 
168   ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr,
169                   ValueShapeMapFn valueToShape = nullptr)
RangeBaseT(values)170       : RangeBaseT(values), operandShape(operandShape),
171         valueToShape(valueToShape) {}
ValueShapeRange(const std::initializer_list<Value> & values)172   ValueShapeRange(const std::initializer_list<Value> &values)
173       : ValueShapeRange(ValueRange(values)) {}
174 
175   ValueShapeRange(const ValueShapeRange &) = default;
176 
177   /// Sets the Value to ShapeAdaptor mapping function and returns this.
setValueToShapeMapping(ValueShapeMapFn fn)178   ValueShapeRange &setValueToShapeMapping(ValueShapeMapFn fn) {
179     valueToShape = fn;
180     return *this;
181   }
182 
setOperandShapeMapping(ValueShapeMapFn fn)183   ValueShapeRange &setOperandShapeMapping(ValueShapeMapFn fn) {
184     operandShape = fn;
185     return *this;
186   }
187 
188   /// Returns the set Value to ShapeAdaptor mapping function.
getValueToShapeMapping()189   ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; }
getOperandShapeMapping()190   ValueShapeMapFn getOperandShapeMapping() const { return operandShape; }
191 
192   // Accessors.
193 
194   /// Returns the types of the values within this range.
195   /// Note: This returns only the types of Values in the ValueRange and not a
196   /// more refined type.
197   using type_iterator = ValueTypeIterator<iterator>;
198   using type_range = ValueTypeRange<ValueRange>;
getTypes()199   type_range getTypes() const { return {begin(), end()}; }
getType()200   auto getType() const { return getTypes(); }
201 
202   /// Returns the Values in the ValueRange.
203   /// To query the most up to date shape of a Value, query the shape
204   /// using getShape below rather than using the type of the Value.
getValues()205   ValueRange getValues() const { return ValueRange(begin(), end()); };
206 
207   /// Returns an argument as shape. If the argument is not constant or not a
208   /// shape, then the function returns a nullptr.
209   /// This will first query the valueToShape mapping (if set), before querying
210   /// the ValueRange.
211   ShapeAdaptor getValueAsShape(int index);
212 
213   /// Returns the shape of index'th operand.
214   // TODO: Update so that operator[] references these instead to avoid
215   // accidentally refering to less refined shape.
216   ShapeAdaptor getShape(int index) const;
217 
218   /// Returns the shape of the given Value.
219   ShapeAdaptor getShape(Value val) const;
220 
221 private:
222   // Mapping from Value to ShapedTypeComponents corresponding to shape of type
223   // of Value.
224   ValueShapeMapFn operandShape;
225 
226   // Mapping from Value to ShapedTypeComponents corresponding to constant Value
227   // if interpreted as shape.
228   ValueShapeMapFn valueToShape;
229 };
230 
231 namespace detail {
232 // Helper function to infer return tensor returns types given element and
233 // shape inference function.
234 //
235 // TODO: Consider generating typedefs for trait member functions if this usage
236 // becomes more common.
237 LogicalResult inferReturnTensorTypes(
238     function_ref<LogicalResult(
239         MLIRContext *, Optional<Location> location, ValueShapeRange operands,
240         DictionaryAttr attributes, RegionRange regions,
241         SmallVectorImpl<ShapedTypeComponents> &retComponents)>
242         componentTypeFn,
243     MLIRContext *context, Optional<Location> location, ValueRange operands,
244     DictionaryAttr attributes, RegionRange regions,
245     SmallVectorImpl<Type> &inferredReturnTypes);
246 
247 /// Verifies that the inferred result types match the actual result types for
248 /// the op. Precondition: op implements InferTypeOpInterface.
249 LogicalResult verifyInferredResultTypes(Operation *op);
250 } // namespace detail
251 
namespace OpTrait {
// Forward declaration so the trait name is visible before the generated
// interface declarations are included below; the definition follows that
// include.
template <typename OpT>
class InferTensorType;
} // namespace OpTrait
256 } // namespace mlir
257 
258 /// Include the generated interface declarations.
259 #include "mlir/Interfaces/InferTypeOpInterface.h.inc"
260 
261 namespace mlir {
262 namespace OpTrait {
263 
264 /// Tensor type inference trait that constructs a tensor from the inferred
265 /// shape and elemental types.
266 /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
267 ///   Less strict is possible (e.g., implements inferReturnTypeComponents and
268 ///   these always populates all element types and shapes or fails, but this\
269 ///   trait is currently only used where the interfaces are, so keep it
270 ///   restricted for now).
271 template <typename ConcreteType>
272 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
273 public:
274   static LogicalResult
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)275   inferReturnTypes(MLIRContext *context, Optional<Location> location,
276                    ValueRange operands, DictionaryAttr attributes,
277                    RegionRange regions,
278                    SmallVectorImpl<Type> &inferredReturnTypes) {
279     static_assert(
280         ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
281         "requires InferShapedTypeOpInterface to ensure succesful invocation");
282     static_assert(
283         ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
284         "requires InferTypeOpInterface to ensure succesful invocation");
285     return ::mlir::detail::inferReturnTensorTypes(
286         ConcreteType::inferReturnTypeComponents, context, location, operands,
287         attributes, regions, inferredReturnTypes);
288   }
289 };
290 
291 } // namespace OpTrait
292 } // namespace mlir
293 
294 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
295