1 //===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the definitions of the infer op interfaces defined in 10 // `InferTypeOpInterface.td`. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_ 15 #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_ 16 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinTypes.h" 20 #include "mlir/IR/Location.h" 21 #include "mlir/IR/OpDefinition.h" 22 #include "mlir/Support/LLVM.h" 23 #include "llvm/ADT/PointerUnion.h" 24 #include "llvm/ADT/SmallVector.h" 25 26 namespace mlir { 27 28 class ShapedTypeComponents; 29 using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<Value>>; 30 31 /// Adaptor class to abstract the differences between whether value is from 32 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute. 33 class ShapeAdaptor { 34 public: ShapeAdaptor(Type t)35 ShapeAdaptor(Type t) { 36 if (auto st = t.dyn_cast<ShapedType>()) 37 val = st; 38 } ShapeAdaptor(Attribute t)39 ShapeAdaptor(Attribute t) { 40 if (auto da = t.dyn_cast<DenseIntElementsAttr>()) 41 val = da; 42 } ShapeAdaptor(ShapedTypeComponents * components)43 ShapeAdaptor(ShapedTypeComponents *components) : val(components) {} ShapeAdaptor(ShapedTypeComponents & components)44 ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {} 45 46 /// Returns whether the shape has a rank. 47 bool hasRank() const; 48 49 /// Returns the element type. 50 Type getElementType() const; 51 52 /// Populates the dimensions from shape referenced. 53 /// Requires: shape is ranked. 54 void getDims(SmallVectorImpl<int64_t> &res) const; 55 56 /// Populates the dimensions of the ShapeTypeComponents. 57 /// Requires: shape is ranked. 58 void getDims(ShapedTypeComponents &res) const; 59 60 /// Returns the size of the index'th dimension. 61 /// Requires: shape is ranked. 62 int64_t getDimSize(int index) const; 63 64 /// Returns whether the index'th dimension is dynamic. 65 /// Requires: shape is ranked. isDynamicDim(int index)66 bool isDynamicDim(int index) const { 67 return ShapedType::isDynamic(getDimSize(index)); 68 } 69 70 /// Returns whether the shape is fully static. 71 bool hasStaticShape() const; 72 73 /// Returns the rank of the shape. 74 /// Requires: shape is ranked. 75 int64_t getRank() const; 76 77 /// Returns the number of elements in the shape. 78 /// Requires: hasStaticShape 79 int64_t getNumElements() const; 80 81 /// Returns whether valid (non-null) shape. 82 explicit operator bool() const { return !val.isNull(); } 83 84 /// Dumps textual repesentation to stderr. 85 void dump() const; 86 87 private: 88 // Union storing either ShapedTypeComponents, ShapedType (stored as Type and 89 // casted), or DenseIntElementsAttribute (stored as Atrtribute). 90 PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr; 91 }; 92 93 /// ShapedTypeComponents that represents the components of a ShapedType. 94 /// The components consist of 95 /// - A ranked or unranked shape with the dimension specification match those 96 /// of ShapeType's getShape() (e.g., dynamic dimension represented using 97 /// ShapedType::kDynamicSize) 98 /// - A element type, may be unset (nullptr) 99 /// - A attribute, may be unset (nullptr) 100 /// Used by ShapedType type inferences. 101 class ShapedTypeComponents { 102 /// Internal storage type for shape. 103 using ShapeStorageT = SmallVector<int64_t, 3>; 104 105 public: 106 /// Default construction is an unranked shape. ShapedTypeComponents()107 ShapedTypeComponents() : elementType(nullptr), attr(nullptr){}; ShapedTypeComponents(Type elementType)108 ShapedTypeComponents(Type elementType) 109 : elementType(elementType), attr(nullptr), ranked(false) {} ShapedTypeComponents(ShapedType shapedType)110 ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) { 111 ranked = shapedType.hasRank(); 112 elementType = shapedType.getElementType(); 113 if (ranked) 114 dims = llvm::to_vector<4>(shapedType.getShape()); 115 } ShapedTypeComponents(ShapeAdaptor adaptor)116 ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) { 117 ranked = adaptor.hasRank(); 118 elementType = adaptor.getElementType(); 119 if (ranked) 120 adaptor.getDims(*this); 121 } 122 template <typename Arg, typename = typename std::enable_if_t< 123 std::is_constructible<ShapeStorageT, Arg>::value>> 124 ShapedTypeComponents(Arg &&arg, Type elementType = nullptr, 125 Attribute attr = nullptr) dims(std::forward<Arg> (arg))126 : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr), 127 ranked(true) {} 128 ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr, 129 Attribute attr = nullptr) 130 : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr), 131 ranked(true) {} 132 133 /// Return the dimensions of the shape. 134 /// Requires: shape is ranked. getDims()135 ArrayRef<int64_t> getDims() const { 136 assert(ranked && "requires ranked shape"); 137 return dims; 138 } 139 140 /// Return whether the shape has a rank. hasRank()141 bool hasRank() const { return ranked; }; 142 143 /// Return the element type component. getElementType()144 Type getElementType() const { return elementType; }; 145 146 /// Return the raw attribute component. getAttribute()147 Attribute getAttribute() const { return attr; }; 148 149 private: 150 friend class ShapeAdaptor; 151 152 ShapeStorageT dims; 153 Type elementType; 154 Attribute attr; 155 bool ranked{false}; 156 }; 157 158 /// Range of values and shapes (corresponding effectively to Shapes dialect's 159 /// ValueShape type concept). 160 // Currently this exposes the Value (of operands) and Type of the Value. This is 161 // not ideal as then one can accidentally reference an out of date shape. This 162 // is done to both enable gradual switch and also as OpAdaptor doesn't currently 163 // allow returning anything other than Value. 164 class ValueShapeRange : public ValueRange::RangeBaseT { 165 public: 166 using ValueShapeMapFn = function_ref<ShapeAdaptor(Value)>; 167 168 ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr, 169 ValueShapeMapFn valueToShape = nullptr) RangeBaseT(values)170 : RangeBaseT(values), operandShape(operandShape), 171 valueToShape(valueToShape) {} ValueShapeRange(const std::initializer_list<Value> & values)172 ValueShapeRange(const std::initializer_list<Value> &values) 173 : ValueShapeRange(ValueRange(values)) {} 174 175 ValueShapeRange(const ValueShapeRange &) = default; 176 177 /// Sets the Value to ShapeAdaptor mapping function and returns this. setValueToShapeMapping(ValueShapeMapFn fn)178 ValueShapeRange &setValueToShapeMapping(ValueShapeMapFn fn) { 179 valueToShape = fn; 180 return *this; 181 } 182 setOperandShapeMapping(ValueShapeMapFn fn)183 ValueShapeRange &setOperandShapeMapping(ValueShapeMapFn fn) { 184 operandShape = fn; 185 return *this; 186 } 187 188 /// Returns the set Value to ShapeAdaptor mapping function. getValueToShapeMapping()189 ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; } getOperandShapeMapping()190 ValueShapeMapFn getOperandShapeMapping() const { return operandShape; } 191 192 // Accessors. 193 194 /// Returns the types of the values within this range. 195 /// Note: This returns only the types of Values in the ValueRange and not a 196 /// more refined type. 197 using type_iterator = ValueTypeIterator<iterator>; 198 using type_range = ValueTypeRange<ValueRange>; getTypes()199 type_range getTypes() const { return {begin(), end()}; } getType()200 auto getType() const { return getTypes(); } 201 202 /// Returns the Values in the ValueRange. 203 /// To query the most up to date shape of a Value, query the shape 204 /// using getShape below rather than using the type of the Value. getValues()205 ValueRange getValues() const { return ValueRange(begin(), end()); }; 206 207 /// Returns an argument as shape. If the argument is not constant or not a 208 /// shape, then the function returns a nullptr. 209 /// This will first query the valueToShape mapping (if set), before querying 210 /// the ValueRange. 211 ShapeAdaptor getValueAsShape(int index); 212 213 /// Returns the shape of index'th operand. 214 // TODO: Update so that operator[] references these instead to avoid 215 // accidentally refering to less refined shape. 216 ShapeAdaptor getShape(int index) const; 217 218 /// Returns the shape of the given Value. 219 ShapeAdaptor getShape(Value val) const; 220 221 private: 222 // Mapping from Value to ShapedTypeComponents corresponding to shape of type 223 // of Value. 224 ValueShapeMapFn operandShape; 225 226 // Mapping from Value to ShapedTypeComponents corresponding to constant Value 227 // if interpreted as shape. 228 ValueShapeMapFn valueToShape; 229 }; 230 231 namespace detail { 232 // Helper function to infer return tensor returns types given element and 233 // shape inference function. 234 // 235 // TODO: Consider generating typedefs for trait member functions if this usage 236 // becomes more common. 237 LogicalResult inferReturnTensorTypes( 238 function_ref<LogicalResult( 239 MLIRContext *, Optional<Location> location, ValueShapeRange operands, 240 DictionaryAttr attributes, RegionRange regions, 241 SmallVectorImpl<ShapedTypeComponents> &retComponents)> 242 componentTypeFn, 243 MLIRContext *context, Optional<Location> location, ValueRange operands, 244 DictionaryAttr attributes, RegionRange regions, 245 SmallVectorImpl<Type> &inferredReturnTypes); 246 247 /// Verifies that the inferred result types match the actual result types for 248 /// the op. Precondition: op implements InferTypeOpInterface. 249 LogicalResult verifyInferredResultTypes(Operation *op); 250 } // namespace detail 251 252 namespace OpTrait { 253 template <typename ConcreteType> 254 class InferTensorType; 255 } // namespace OpTrait 256 } // namespace mlir 257 258 /// Include the generated interface declarations. 259 #include "mlir/Interfaces/InferTypeOpInterface.h.inc" 260 261 namespace mlir { 262 namespace OpTrait { 263 264 /// Tensor type inference trait that constructs a tensor from the inferred 265 /// shape and elemental types. 266 /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface. 267 /// Less strict is possible (e.g., implements inferReturnTypeComponents and 268 /// these always populates all element types and shapes or fails, but this\ 269 /// trait is currently only used where the interfaces are, so keep it 270 /// restricted for now). 271 template <typename ConcreteType> 272 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> { 273 public: 274 static LogicalResult inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)275 inferReturnTypes(MLIRContext *context, Optional<Location> location, 276 ValueRange operands, DictionaryAttr attributes, 277 RegionRange regions, 278 SmallVectorImpl<Type> &inferredReturnTypes) { 279 static_assert( 280 ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(), 281 "requires InferShapedTypeOpInterface to ensure succesful invocation"); 282 static_assert( 283 ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(), 284 "requires InferTypeOpInterface to ensure succesful invocation"); 285 return ::mlir::detail::inferReturnTensorTypes( 286 ConcreteType::inferReturnTypeComponents, context, location, operands, 287 attributes, regions, inferredReturnTypes); 288 } 289 }; 290 291 } // namespace OpTrait 292 } // namespace mlir 293 294 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_ 295