1 //===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines several support classes for defining interfaces. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_SUPPORT_INTERFACESUPPORT_H 14 #define MLIR_SUPPORT_INTERFACESUPPORT_H 15 16 #include "mlir/Support/TypeID.h" 17 #include "llvm/ADT/ArrayRef.h" 18 #include "llvm/ADT/DenseMap.h" 19 #include "llvm/Support/TypeName.h" 20 21 namespace mlir { 22 namespace detail { 23 //===----------------------------------------------------------------------===// 24 // Interface 25 //===----------------------------------------------------------------------===// 26 27 /// This class represents an abstract interface. An interface is a simplified 28 /// mechanism for attaching concept based polymorphism to a class hierarchy. An 29 /// interface is comprised of two components: 30 /// * The derived interface class: This is what users interact with, and invoke 31 /// methods on. 32 /// * An interface `Trait` class: This is the class that is attached to the 33 /// object implementing the interface. It is the mechanism with which models 34 /// are specialized. 35 /// 36 /// Derived interfaces types must provide the following template types: 37 /// * ConcreteType: The CRTP derived type. 38 /// * ValueT: The opaque type the derived interface operates on. For example 39 /// `Operation*` for operation interfaces, or `Attribute` for 40 /// attribute interfaces. 41 /// * Traits: A class that contains definitions for a 'Concept' and a 'Model' 42 /// class. The 'Concept' class defines an abstract virtual interface, 43 /// where as the 'Model' class implements this interface for a 44 /// specific derived T type. Both of these classes *must* not contain 45 /// non-static data. A simple example is shown below: 46 /// 47 /// ```c++ 48 /// struct ExampleInterfaceTraits { 49 /// struct Concept { 50 /// virtual unsigned getNumInputs(T t) const = 0; 51 /// }; 52 /// template <typename DerivedT> class Model { 53 /// unsigned getNumInputs(T t) const final { 54 /// return cast<DerivedT>(t).getNumInputs(); 55 /// } 56 /// }; 57 /// }; 58 /// ``` 59 /// 60 /// * BaseType: A desired base type for the interface. This is a class that 61 /// provides that provides specific functionality for the `ValueT` 62 /// value. For instance the specific `Op` that will wrap the 63 /// `Operation*` for an `OpInterface`. 64 /// * BaseTrait: The base type for the interface trait. This is the base class 65 /// to use for the interface trait that will be attached to each 66 /// instance of `ValueT` that implements this interface. 67 /// 68 template <typename ConcreteType, typename ValueT, typename Traits, 69 typename BaseType, 70 template <typename, template <typename> class> class BaseTrait> 71 class Interface : public BaseType { 72 public: 73 using Concept = typename Traits::Concept; 74 template <typename T> 75 using Model = typename Traits::template Model<T>; 76 template <typename T> 77 using FallbackModel = typename Traits::template FallbackModel<T>; 78 using InterfaceBase = 79 Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>; 80 template <typename T, typename U> 81 using ExternalModel = typename Traits::template ExternalModel<T, U>; 82 using ValueType = ValueT; 83 84 /// This is a special trait that registers a given interface with an object. 85 template <typename ConcreteT> 86 struct Trait : public BaseTrait<ConcreteT, Trait> { 87 using ModelT = Model<ConcreteT>; 88 89 /// Define an accessor for the ID of this interface. getInterfaceIDTrait90 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } 91 }; 92 93 /// Construct an interface from an instance of the value type. 94 Interface(ValueT t = ValueT()) BaseType(t)95 : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { 96 assert((!t || impl) && "expected value to provide interface instance"); 97 } Interface(std::nullptr_t)98 Interface(std::nullptr_t) : BaseType(ValueT()), impl(nullptr) {} 99 100 /// Construct an interface instance from a type that implements this 101 /// interface's trait. 102 template <typename T, typename std::enable_if_t< 103 std::is_base_of<Trait<T>, T>::value> * = nullptr> Interface(T t)104 Interface(T t) 105 : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { 106 assert((!t || impl) && "expected value to provide interface instance"); 107 } 108 109 /// Constructor for DenseMapInfo's empty key and tombstone key. Interface(ValueT t,std::nullptr_t)110 Interface(ValueT t, std::nullptr_t) : BaseType(t), impl(nullptr) {} 111 112 /// Support 'classof' by checking if the given object defines the concrete 113 /// interface. classof(ValueT t)114 static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); } 115 116 /// Define an accessor for the ID of this interface. getInterfaceID()117 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } 118 119 protected: 120 /// Get the raw concept in the correct derived concept type. getImpl()121 const Concept *getImpl() const { return impl; } getImpl()122 Concept *getImpl() { return impl; } 123 124 private: 125 /// A pointer to the impl concept object. 126 Concept *impl; 127 }; 128 129 //===----------------------------------------------------------------------===// 130 // InterfaceMap 131 //===----------------------------------------------------------------------===// 132 133 /// Template utility that computes the number of elements within `T` that 134 /// satisfy the given predicate. 135 template <template <class> class Pred, size_t N, typename... Ts> 136 struct count_if_t_impl : public std::integral_constant<size_t, N> {}; 137 template <template <class> class Pred, size_t N, typename T, typename... Us> 138 struct count_if_t_impl<Pred, N, T, Us...> 139 : public std::integral_constant< 140 size_t, 141 count_if_t_impl<Pred, N + (Pred<T>::value ? 1 : 0), Us...>::value> {}; 142 template <template <class> class Pred, typename... Ts> 143 using count_if_t = count_if_t_impl<Pred, 0, Ts...>; 144 145 namespace { 146 /// Type trait indicating whether all template arguments are 147 /// trivially-destructible. 148 template <typename... Args> 149 struct all_trivially_destructible; 150 151 template <typename Arg, typename... Args> 152 struct all_trivially_destructible<Arg, Args...> { 153 static constexpr const bool value = 154 std::is_trivially_destructible<Arg>::value && 155 all_trivially_destructible<Args...>::value; 156 }; 157 158 template <> 159 struct all_trivially_destructible<> { 160 static constexpr const bool value = true; 161 }; 162 } // namespace 163 164 /// This class provides an efficient mapping between a given `Interface` type, 165 /// and a particular implementation of its concept. 166 class InterfaceMap { 167 /// Trait to check if T provides a static 'getInterfaceID' method. 168 template <typename T, typename... Args> 169 using has_get_interface_id = decltype(T::getInterfaceID()); 170 template <typename T> 171 using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>; 172 template <typename... Types> 173 using num_interface_types_t = count_if_t<detect_get_interface_id, Types...>; 174 175 public: 176 InterfaceMap(InterfaceMap &&) = default; 177 InterfaceMap &operator=(InterfaceMap &&rhs) { 178 for (auto &it : interfaces) 179 free(it.second); 180 interfaces = std::move(rhs.interfaces); 181 return *this; 182 } 183 ~InterfaceMap() { 184 for (auto &it : interfaces) 185 free(it.second); 186 } 187 188 /// Construct an InterfaceMap with the given set of template types. For 189 /// convenience given that object trait lists may contain other non-interface 190 /// types, not all of the types need to be interfaces. The provided types that 191 /// do not represent interfaces are not added to the interface map. 192 template <typename... Types> 193 static InterfaceMap get() { 194 // TODO: Use constexpr if here in C++17. 195 constexpr size_t numInterfaces = num_interface_types_t<Types...>::value; 196 if (numInterfaces == 0) 197 return InterfaceMap(); 198 199 std::array<std::pair<TypeID, void *>, numInterfaces> elements; 200 std::pair<TypeID, void *> *elementIt = elements.data(); 201 (void)elementIt; 202 (void)std::initializer_list<int>{ 203 0, (addModelAndUpdateIterator<Types>(elementIt), 0)...}; 204 return InterfaceMap(elements); 205 } 206 207 /// Returns an instance of the concept object for the given interface if it 208 /// was registered to this map, null otherwise. 209 template <typename T> 210 typename T::Concept *lookup() const { 211 return reinterpret_cast<typename T::Concept *>(lookup(T::getInterfaceID())); 212 } 213 214 /// Returns true if the interface map contains an interface for the given id. 215 bool contains(TypeID interfaceID) const { return lookup(interfaceID); } 216 217 /// Create an InterfaceMap given with the implementation of the interfaces. 218 /// The use of this constructor is in general discouraged in favor of 219 /// 'InterfaceMap::get<InterfaceA, ...>()'. 220 InterfaceMap(MutableArrayRef<std::pair<TypeID, void *>> elements); 221 222 /// Insert the given models as implementations of the corresponding interfaces 223 /// for the concrete attribute class. 224 template <typename... IfaceModels> 225 void insert() { 226 static_assert(all_trivially_destructible<IfaceModels...>::value, 227 "interface models must be trivially destructible"); 228 std::pair<TypeID, void *> elements[] = { 229 std::make_pair(IfaceModels::Interface::getInterfaceID(), 230 new (malloc(sizeof(IfaceModels))) IfaceModels())...}; 231 insert(elements); 232 } 233 234 private: 235 InterfaceMap() = default; 236 237 /// Assign the interface model of the type to the given opaque element 238 /// iterator and increment it. 239 template <typename T> 240 static inline std::enable_if_t<detect_get_interface_id<T>::value> 241 addModelAndUpdateIterator(std::pair<TypeID, void *> *&elementIt) { 242 *elementIt = {T::getInterfaceID(), new (malloc(sizeof(typename T::ModelT))) 243 typename T::ModelT()}; 244 ++elementIt; 245 } 246 /// Overload when `T` isn't an interface. 247 template <typename T> 248 static inline std::enable_if_t<!detect_get_interface_id<T>::value> 249 addModelAndUpdateIterator(std::pair<TypeID, void *> *&) {} 250 251 /// Insert the given set of interface models into the interface map. 252 void insert(ArrayRef<std::pair<TypeID, void *>> elements); 253 254 /// Compare two TypeID instances by comparing the underlying pointer. 255 static bool compare(TypeID lhs, TypeID rhs) { 256 return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer(); 257 } 258 259 /// Returns an instance of the concept object for the given interface id if it 260 /// was registered to this map, null otherwise. 261 void *lookup(TypeID id) const { 262 const auto *it = 263 llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) { 264 return compare(it.first, id); 265 }); 266 return (it != interfaces.end() && it->first == id) ? it->second : nullptr; 267 } 268 269 /// A list of interface instances, sorted by TypeID. 270 SmallVector<std::pair<TypeID, void *>> interfaces; 271 }; 272 273 template <typename ConcreteType, typename ValueT, typename Traits, 274 typename BaseType, 275 template <typename, template <typename> class> class BaseTrait> 276 void isInterfaceImpl( 277 Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait> &); 278 279 template <typename T> 280 using is_interface_t = decltype(isInterfaceImpl(std::declval<T &>())); 281 282 template <typename T> 283 using IsInterface = llvm::is_detected<is_interface_t, T>; 284 285 } // namespace detail 286 } // namespace mlir 287 288 namespace llvm { 289 290 template <typename T> 291 struct DenseMapInfo<T, std::enable_if_t<mlir::detail::IsInterface<T>::value>> { 292 using ValueTypeInfo = llvm::DenseMapInfo<typename T::ValueType>; 293 294 static T getEmptyKey() { return T(ValueTypeInfo::getEmptyKey(), nullptr); } 295 296 static T getTombstoneKey() { 297 return T(ValueTypeInfo::getTombstoneKey(), nullptr); 298 } 299 300 static unsigned getHashValue(T val) { 301 return ValueTypeInfo::getHashValue(val); 302 } 303 304 static bool isEqual(T lhs, T rhs) { return ValueTypeInfo::isEqual(lhs, rhs); } 305 }; 306 307 } // namespace llvm 308 309 #endif 310