1 //===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines several support classes for defining interfaces.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_SUPPORT_INTERFACESUPPORT_H
14 #define MLIR_SUPPORT_INTERFACESUPPORT_H
15 
16 #include "mlir/Support/TypeID.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/Support/TypeName.h"
20 
21 namespace mlir {
22 namespace detail {
23 //===----------------------------------------------------------------------===//
24 // Interface
25 //===----------------------------------------------------------------------===//
26 
/// This class represents an abstract interface. An interface is a simplified
/// mechanism for attaching concept-based polymorphism to a class hierarchy. An
/// interface is comprised of two components:
/// * The derived interface class: This is what users interact with, and invoke
///   methods on.
/// * An interface `Trait` class: This is the class that is attached to the
///   object implementing the interface. It is the mechanism with which models
///   are specialized.
///
/// Derived interface types must provide the following template types:
/// * ConcreteType: The CRTP derived type.
/// * ValueT: The opaque type the derived interface operates on. For example
///           `Operation*` for operation interfaces, or `Attribute` for
///           attribute interfaces.
/// * Traits: A class that contains definitions for a 'Concept' and a 'Model'
///           class. The 'Concept' class defines an abstract virtual interface,
///           whereas the 'Model' class implements this interface for a
///           specific derived T type. Neither of these classes may contain
///           non-static data. A simple example is shown below (here `T`
///           stands for the `ValueT` type):
///
/// ```c++
///    struct ExampleInterfaceTraits {
///      struct Concept {
///        virtual unsigned getNumInputs(T t) const = 0;
///      };
///      template <typename DerivedT> struct Model : public Concept {
///        unsigned getNumInputs(T t) const final {
///          return cast<DerivedT>(t).getNumInputs();
///        }
///      };
///    };
/// ```
///
/// * BaseType: A desired base type for the interface. This is a class that
///             provides specific functionality for the `ValueT` value. For
///             instance the specific `Op` that will wrap the `Operation*` for
///             an `OpInterface`.
/// * BaseTrait: The base type for the interface trait. This is the base class
///              to use for the interface trait that will be attached to each
///              instance of `ValueT` that implements this interface.
///
68 template <typename ConcreteType, typename ValueT, typename Traits,
69           typename BaseType,
70           template <typename, template <typename> class> class BaseTrait>
71 class Interface : public BaseType {
72 public:
73   using Concept = typename Traits::Concept;
74   template <typename T>
75   using Model = typename Traits::template Model<T>;
76   template <typename T>
77   using FallbackModel = typename Traits::template FallbackModel<T>;
78   using InterfaceBase =
79       Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
80   template <typename T, typename U>
81   using ExternalModel = typename Traits::template ExternalModel<T, U>;
82   using ValueType = ValueT;
83 
84   /// This is a special trait that registers a given interface with an object.
85   template <typename ConcreteT>
86   struct Trait : public BaseTrait<ConcreteT, Trait> {
87     using ModelT = Model<ConcreteT>;
88 
89     /// Define an accessor for the ID of this interface.
getInterfaceIDTrait90     static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
91   };
92 
93   /// Construct an interface from an instance of the value type.
94   Interface(ValueT t = ValueT())
BaseType(t)95       : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
96     assert((!t || impl) && "expected value to provide interface instance");
97   }
Interface(std::nullptr_t)98   Interface(std::nullptr_t) : BaseType(ValueT()), impl(nullptr) {}
99 
100   /// Construct an interface instance from a type that implements this
101   /// interface's trait.
102   template <typename T, typename std::enable_if_t<
103                             std::is_base_of<Trait<T>, T>::value> * = nullptr>
Interface(T t)104   Interface(T t)
105       : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
106     assert((!t || impl) && "expected value to provide interface instance");
107   }
108 
109   /// Constructor for DenseMapInfo's empty key and tombstone key.
Interface(ValueT t,std::nullptr_t)110   Interface(ValueT t, std::nullptr_t) : BaseType(t), impl(nullptr) {}
111 
112   /// Support 'classof' by checking if the given object defines the concrete
113   /// interface.
classof(ValueT t)114   static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); }
115 
116   /// Define an accessor for the ID of this interface.
getInterfaceID()117   static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
118 
119 protected:
120   /// Get the raw concept in the correct derived concept type.
getImpl()121   const Concept *getImpl() const { return impl; }
getImpl()122   Concept *getImpl() { return impl; }
123 
124 private:
125   /// A pointer to the impl concept object.
126   Concept *impl;
127 };
128 
129 //===----------------------------------------------------------------------===//
130 // InterfaceMap
131 //===----------------------------------------------------------------------===//
132 
/// Compile-time count of how many of the types in `Ts...` satisfy `Pred`
/// (i.e. have a true `Pred<T>::value`), added to the starting count `N`.
///
/// Base case: no types remain, so the result is the accumulated count `N`.
template <template <class> class Pred, size_t N, typename... Ts>
struct count_if_t_impl : public std::integral_constant<size_t, N> {};
/// Recursive case: peel off the first type, bump the running count when it
/// matches, and continue with the remaining types.
template <template <class> class Pred, size_t N, typename T, typename... Us>
struct count_if_t_impl<Pred, N, T, Us...>
    : public count_if_t_impl<Pred,
                             N + (Pred<T>::value ? size_t(1) : size_t(0)),
                             Us...> {};
/// Entry point: count the types in `Ts...` satisfying `Pred`, starting at 0.
template <template <class> class Pred, typename... Ts>
using count_if_t = count_if_t_impl<Pred, 0, Ts...>;
144 
/// Type trait indicating whether all template arguments are
/// trivially-destructible.
///
/// This is deliberately not placed in an anonymous namespace. This is a
/// header, so an unnamed namespace would give every translation unit its own
/// distinct copy of the trait. The inline `InterfaceMap::insert` refers to the
/// trait, and that would violate the one-definition rule.
template <typename... Args>
struct all_trivially_destructible;

/// Recursive case: the first type must be trivially destructible, and so must
/// all of the remaining ones.
template <typename Arg, typename... Args>
struct all_trivially_destructible<Arg, Args...>
    : public std::integral_constant<
          bool, std::is_trivially_destructible<Arg>::value &&
                    all_trivially_destructible<Args...>::value> {};

/// Base case: an empty pack is vacuously trivially destructible.
template <>
struct all_trivially_destructible<> : public std::true_type {};
163 
164 /// This class provides an efficient mapping between a given `Interface` type,
165 /// and a particular implementation of its concept.
166 class InterfaceMap {
167   /// Trait to check if T provides a static 'getInterfaceID' method.
168   template <typename T, typename... Args>
169   using has_get_interface_id = decltype(T::getInterfaceID());
170   template <typename T>
171   using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
172   template <typename... Types>
173   using num_interface_types_t = count_if_t<detect_get_interface_id, Types...>;
174 
175 public:
176   InterfaceMap(InterfaceMap &&) = default;
177   InterfaceMap &operator=(InterfaceMap &&rhs) {
178     for (auto &it : interfaces)
179       free(it.second);
180     interfaces = std::move(rhs.interfaces);
181     return *this;
182   }
183   ~InterfaceMap() {
184     for (auto &it : interfaces)
185       free(it.second);
186   }
187 
188   /// Construct an InterfaceMap with the given set of template types. For
189   /// convenience given that object trait lists may contain other non-interface
190   /// types, not all of the types need to be interfaces. The provided types that
191   /// do not represent interfaces are not added to the interface map.
192   template <typename... Types>
193   static InterfaceMap get() {
194     // TODO: Use constexpr if here in C++17.
195     constexpr size_t numInterfaces = num_interface_types_t<Types...>::value;
196     if (numInterfaces == 0)
197       return InterfaceMap();
198 
199     std::array<std::pair<TypeID, void *>, numInterfaces> elements;
200     std::pair<TypeID, void *> *elementIt = elements.data();
201     (void)elementIt;
202     (void)std::initializer_list<int>{
203         0, (addModelAndUpdateIterator<Types>(elementIt), 0)...};
204     return InterfaceMap(elements);
205   }
206 
207   /// Returns an instance of the concept object for the given interface if it
208   /// was registered to this map, null otherwise.
209   template <typename T>
210   typename T::Concept *lookup() const {
211     return reinterpret_cast<typename T::Concept *>(lookup(T::getInterfaceID()));
212   }
213 
214   /// Returns true if the interface map contains an interface for the given id.
215   bool contains(TypeID interfaceID) const { return lookup(interfaceID); }
216 
217   /// Create an InterfaceMap given with the implementation of the interfaces.
218   /// The use of this constructor is in general discouraged in favor of
219   /// 'InterfaceMap::get<InterfaceA, ...>()'.
220   InterfaceMap(MutableArrayRef<std::pair<TypeID, void *>> elements);
221 
222   /// Insert the given models as implementations of the corresponding interfaces
223   /// for the concrete attribute class.
224   template <typename... IfaceModels>
225   void insert() {
226     static_assert(all_trivially_destructible<IfaceModels...>::value,
227                   "interface models must be trivially destructible");
228     std::pair<TypeID, void *> elements[] = {
229         std::make_pair(IfaceModels::Interface::getInterfaceID(),
230                        new (malloc(sizeof(IfaceModels))) IfaceModels())...};
231     insert(elements);
232   }
233 
234 private:
235   InterfaceMap() = default;
236 
237   /// Assign the interface model of the type to the given opaque element
238   /// iterator and increment it.
239   template <typename T>
240   static inline std::enable_if_t<detect_get_interface_id<T>::value>
241   addModelAndUpdateIterator(std::pair<TypeID, void *> *&elementIt) {
242     *elementIt = {T::getInterfaceID(), new (malloc(sizeof(typename T::ModelT)))
243                                            typename T::ModelT()};
244     ++elementIt;
245   }
246   /// Overload when `T` isn't an interface.
247   template <typename T>
248   static inline std::enable_if_t<!detect_get_interface_id<T>::value>
249   addModelAndUpdateIterator(std::pair<TypeID, void *> *&) {}
250 
251   /// Insert the given set of interface models into the interface map.
252   void insert(ArrayRef<std::pair<TypeID, void *>> elements);
253 
254   /// Compare two TypeID instances by comparing the underlying pointer.
255   static bool compare(TypeID lhs, TypeID rhs) {
256     return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
257   }
258 
259   /// Returns an instance of the concept object for the given interface id if it
260   /// was registered to this map, null otherwise.
261   void *lookup(TypeID id) const {
262     const auto *it =
263         llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
264           return compare(it.first, id);
265         });
266     return (it != interfaces.end() && it->first == id) ? it->second : nullptr;
267   }
268 
269   /// A list of interface instances, sorted by TypeID.
270   SmallVector<std::pair<TypeID, void *>> interfaces;
271 };
272 
273 template <typename ConcreteType, typename ValueT, typename Traits,
274           typename BaseType,
275           template <typename, template <typename> class> class BaseTrait>
276 void isInterfaceImpl(
277     Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait> &);
278 
279 template <typename T>
280 using is_interface_t = decltype(isInterfaceImpl(std::declval<T &>()));
281 
282 template <typename T>
283 using IsInterface = llvm::is_detected<is_interface_t, T>;
284 
285 } // namespace detail
286 } // namespace mlir
287 
288 namespace llvm {
289 
290 template <typename T>
291 struct DenseMapInfo<T, std::enable_if_t<mlir::detail::IsInterface<T>::value>> {
292   using ValueTypeInfo = llvm::DenseMapInfo<typename T::ValueType>;
293 
294   static T getEmptyKey() { return T(ValueTypeInfo::getEmptyKey(), nullptr); }
295 
296   static T getTombstoneKey() {
297     return T(ValueTypeInfo::getTombstoneKey(), nullptr);
298   }
299 
300   static unsigned getHashValue(T val) {
301     return ValueTypeInfo::getHashValue(val);
302   }
303 
304   static bool isEqual(T lhs, T rhs) { return ValueTypeInfo::isEqual(lhs, rhs); }
305 };
306 
307 } // namespace llvm
308 
309 #endif
310