1 //===- StorageUniquerSupport.h - MLIR Storage Uniquer Utilities -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utility classes for interfacing with StorageUniquer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H
14 #define MLIR_IR_STORAGEUNIQUERSUPPORT_H
15 
16 #include "mlir/Support/InterfaceSupport.h"
17 #include "mlir/Support/LogicalResult.h"
18 #include "mlir/Support/StorageUniquer.h"
19 #include "mlir/Support/TypeID.h"
20 #include "llvm/ADT/FunctionExtras.h"
21 
22 namespace mlir {
23 class InFlightDiagnostic;
24 class Location;
25 class MLIRContext;
26 
27 namespace detail {
/// Utility method to generate a callback that can be used to generate a
/// diagnostic when checking the construction invariants of a storage object.
/// This is defined out-of-line to avoid the need to include Location.h.
///
/// This overload builds the callback from a context only. It is the one used
/// by `StorageUserBase::get`, which receives a context but no location.
llvm::unique_function<InFlightDiagnostic()>
getDefaultDiagnosticEmitFn(MLIRContext *ctx);
/// Overload that builds the callback from a location. It is the one used by
/// the `Location`-taking `StorageUserBase::getChecked`.
llvm::unique_function<InFlightDiagnostic()>
getDefaultDiagnosticEmitFn(const Location &loc);
35 
36 //===----------------------------------------------------------------------===//
37 // StorageUserTraitBase
38 //===----------------------------------------------------------------------===//
39 
/// Common base for traits attached to storage classes. It carries no state
/// and only offers a way for a trait to get back to the concrete object it is
/// attached to. Clients are not expected to use this directly, hence every
/// member is protected.
///
/// `ConcreteType` is the class the trait is attached to and `TraitType` is the
/// trait template deriving from this base.
template <typename ConcreteType, template <typename> class TraitType>
class StorageUserTraitBase {
protected:
  /// Return the derived instance (by value).
  ConcreteType getInstance() const {
    // The concrete type inherits from this (content free) base once per
    // attached trait, so a direct downcast would be ambiguous. Go through the
    // specific trait type first to pin down which inheritance path is meant,
    // and only then step down to the concrete type.
    using TraitT = TraitType<ConcreteType>;
    const TraitT *asTrait = static_cast<const TraitT *>(this);
    const ConcreteType *asConcrete =
        static_cast<const ConcreteType *>(asTrait);
    return *asConcrete;
  }
};
55 
56 namespace StorageUserTrait {
57 /// This trait is used to determine if a storage user, like Type, is mutable
58 /// or not. A storage user is mutable if ImplType of the derived class defines
59 /// a `mutate` function with a proper signature. Note that this trait is not
60 /// supposed to be used publicly. Users should use alias names like
61 /// `TypeTrait::IsMutable` instead.
62 template <typename ConcreteType>
63 struct IsMutable : public StorageUserTraitBase<ConcreteType, IsMutable> {};
64 } // namespace StorageUserTrait
65 
66 //===----------------------------------------------------------------------===//
67 // StorageUserBase
68 //===----------------------------------------------------------------------===//
69 
70 namespace storage_user_base_impl {
71 /// Returns true if this given Trait ID matches the IDs of any of the provided
72 /// trait types `Traits`.
73 template <template <typename T> class... Traits>
hasTrait(TypeID traitID)74 bool hasTrait(TypeID traitID) {
75   TypeID traitIDs[] = {TypeID::get<Traits>()...};
76   for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
77     if (traitIDs[i] == traitID)
78       return true;
79   return false;
80 }
81 
82 // We specialize for the empty case to not define an empty array.
83 template <>
hasTrait(TypeID traitID)84 inline bool hasTrait(TypeID traitID) {
85   return false;
86 }
87 } // namespace storage_user_base_impl
88 
89 /// Utility class for implementing users of storage classes uniqued by a
90 /// StorageUniquer. Clients are not expected to interact with this class
91 /// directly.
92 template <typename ConcreteT, typename BaseT, typename StorageT,
93           typename UniquerT, template <typename T> class... Traits>
94 class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
95 public:
96   using BaseT::BaseT;
97 
98   /// Utility declarations for the concrete attribute class.
99   using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT, Traits...>;
100   using ImplType = StorageT;
101   using HasTraitFn = bool (*)(TypeID);
102 
103   /// Return a unique identifier for the concrete type.
getTypeID()104   static TypeID getTypeID() { return TypeID::get<ConcreteT>(); }
105 
106   /// Provide an implementation of 'classof' that compares the type id of the
107   /// provided value with that of the concrete type.
108   template <typename T>
classof(T val)109   static bool classof(T val) {
110     static_assert(std::is_convertible<ConcreteT, T>::value,
111                   "casting from a non-convertible type");
112     return val.getTypeID() == getTypeID();
113   }
114 
115   /// Returns an interface map for the interfaces registered to this storage
116   /// user. This should not be used directly.
getInterfaceMap()117   static detail::InterfaceMap getInterfaceMap() {
118     return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
119   }
120 
121   /// Returns the function that returns true if the given Trait ID matches the
122   /// IDs of any of the traits defined by the storage user.
getHasTraitFn()123   static HasTraitFn getHasTraitFn() {
124     return [](TypeID id) {
125       return storage_user_base_impl::hasTrait<Traits...>(id);
126     };
127   }
128 
129   /// Attach the given models as implementations of the corresponding interfaces
130   /// for the concrete storage user class. The type must be registered with the
131   /// context, i.e. the dialect to which the type belongs must be loaded. The
132   /// call will abort otherwise.
133   template <typename... IfaceModels>
attachInterface(MLIRContext & context)134   static void attachInterface(MLIRContext &context) {
135     typename ConcreteT::AbstractTy *abstract =
136         ConcreteT::AbstractTy::lookupMutable(TypeID::get<ConcreteT>(),
137                                              &context);
138     if (!abstract)
139       llvm::report_fatal_error("Registering an interface for an attribute/type "
140                                "that is not itself registered.");
141     (void)std::initializer_list<int>{
142         (checkInterfaceTarget<IfaceModels>(), 0)...};
143     abstract->interfaceMap.template insert<IfaceModels...>();
144   }
145 
146   /// Get or create a new ConcreteT instance within the ctx. This
147   /// function is guaranteed to return a non null object and will assert if
148   /// the arguments provided are invalid.
149   template <typename... Args>
get(MLIRContext * ctx,Args...args)150   static ConcreteT get(MLIRContext *ctx, Args... args) {
151     // Ensure that the invariants are correct for construction.
152     assert(
153         succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
154     return UniquerT::template get<ConcreteT>(ctx, args...);
155   }
156 
157   /// Get or create a new ConcreteT instance within the ctx, defined at
158   /// the given, potentially unknown, location. If the arguments provided are
159   /// invalid, errors are emitted using the provided location and a null object
160   /// is returned.
161   template <typename... Args>
getChecked(const Location & loc,Args...args)162   static ConcreteT getChecked(const Location &loc, Args... args) {
163     return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), args...);
164   }
165 
166   /// Get or create a new ConcreteT instance within the ctx. If the arguments
167   /// provided are invalid, errors are emitted using the provided `emitError`
168   /// and a null object is returned.
169   template <typename... Args>
getChecked(function_ref<InFlightDiagnostic ()> emitErrorFn,MLIRContext * ctx,Args...args)170   static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
171                               MLIRContext *ctx, Args... args) {
172     // If the construction invariants fail then we return a null attribute.
173     if (failed(ConcreteT::verify(emitErrorFn, args...)))
174       return ConcreteT();
175     return UniquerT::template get<ConcreteT>(ctx, args...);
176   }
177 
178   /// Get an instance of the concrete type from a void pointer.
getFromOpaquePointer(const void * ptr)179   static ConcreteT getFromOpaquePointer(const void *ptr) {
180     return ConcreteT((const typename BaseT::ImplType *)ptr);
181   }
182 
183 protected:
184   /// Mutate the current storage instance. This will not change the unique key.
185   /// The arguments are forwarded to 'ConcreteT::mutate'.
186   template <typename... Args>
mutate(Args &&...args)187   LogicalResult mutate(Args &&...args) {
188     static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>,
189                                   ConcreteT>::value,
190                   "The `mutate` function expects mutable trait "
191                   "(e.g. TypeTrait::IsMutable) to be attached on parent.");
192     return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
193                                                 std::forward<Args>(args)...);
194   }
195 
196   /// Default implementation that just returns success.
197   template <typename... Args>
verify(Args...args)198   static LogicalResult verify(Args... args) {
199     return success();
200   }
201 
202   /// Utility for easy access to the storage instance.
getImpl()203   ImplType *getImpl() const { return static_cast<ImplType *>(this->impl); }
204 
205 private:
206   /// Trait to check if T provides a 'ConcreteEntity' type alias.
207   template <typename T>
208   using has_concrete_entity_t = typename T::ConcreteEntity;
209 
210   /// A struct-wrapped type alias to T::ConcreteEntity if provided and to
211   /// ConcreteT otherwise. This is akin to std::conditional but doesn't fail on
212   /// the missing typedef. Useful for checking if the interface is targeting the
213   /// right class.
214   template <typename T,
215             bool = llvm::is_detected<has_concrete_entity_t, T>::value>
216   struct IfaceTargetOrConcreteT {
217     using type = typename T::ConcreteEntity;
218   };
219   template <typename T>
220   struct IfaceTargetOrConcreteT<T, false> {
221     using type = ConcreteT;
222   };
223 
224   /// A hook for static assertion that the external interface model T is
225   /// targeting a base class of the concrete attribute/type. The model can also
226   /// be a fallback model that works for every attribute/type.
227   template <typename T>
228   static void checkInterfaceTarget() {
229     static_assert(std::is_base_of<typename IfaceTargetOrConcreteT<T>::type,
230                                   ConcreteT>::value,
231                   "attaching an interface to the wrong attribute/type kind");
232   }
233 };
234 } // namespace detail
235 } // namespace mlir
236 
237 #endif
238