1 //===- StorageUniquerSupport.h - MLIR Storage Uniquer Utilities -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utility classes for interfacing with StorageUniquer.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H
14 #define MLIR_IR_STORAGEUNIQUERSUPPORT_H
15
16 #include "mlir/Support/InterfaceSupport.h"
17 #include "mlir/Support/LogicalResult.h"
18 #include "mlir/Support/StorageUniquer.h"
19 #include "mlir/Support/TypeID.h"
20 #include "llvm/ADT/FunctionExtras.h"
21
22 namespace mlir {
// Forward declarations: these classes are only used here through pointers,
// references, or function signatures, so their headers need not be included.
class MLIRContext;
class Location;
class InFlightDiagnostic;
26
27 namespace detail {
28 /// Utility method to generate a callback that can be used to generate a
29 /// diagnostic when checking the construction invariants of a storage object.
30 /// This is defined out-of-line to avoid the need to include Location.h.
31 llvm::unique_function<InFlightDiagnostic()>
32 getDefaultDiagnosticEmitFn(MLIRContext *ctx);
33 llvm::unique_function<InFlightDiagnostic()>
34 getDefaultDiagnosticEmitFn(const Location &loc);
35
36 //===----------------------------------------------------------------------===//
37 // StorageUserTraitBase
38 //===----------------------------------------------------------------------===//
39
/// Common base for the traits that can be attached to a storage user class.
/// It only provides `getInstance()`, which recovers the concrete object the
/// trait is mixed into. Clients are not expected to use this class directly,
/// so its members are protected.
///
/// \tparam ConcreteType the concrete storage user class carrying the trait.
/// \tparam TraitType    the trait template that derives from this base.
template <typename ConcreteType, template <typename> class TraitType>
class StorageUserTraitBase {
protected:
  /// Return (a copy of) the concrete object this trait is attached to.
  ConcreteType getInstance() const {
    // The concrete type derives from several content-free trait bases. We
    // therefore step through the specific trait type first and only then
    // down to the concrete type, which gives the compiler an unambiguous
    // inheritance path to follow.
    using Trait = TraitType<ConcreteType>;
    const Trait &asTrait = static_cast<const Trait &>(*this);
    return static_cast<const ConcreteType &>(asTrait);
  }
};
55
56 namespace StorageUserTrait {
57 /// This trait is used to determine if a storage user, like Type, is mutable
58 /// or not. A storage user is mutable if ImplType of the derived class defines
59 /// a `mutate` function with a proper signature. Note that this trait is not
60 /// supposed to be used publicly. Users should use alias names like
61 /// `TypeTrait::IsMutable` instead.
62 template <typename ConcreteType>
63 struct IsMutable : public StorageUserTraitBase<ConcreteType, IsMutable> {};
64 } // namespace StorageUserTrait
65
66 //===----------------------------------------------------------------------===//
67 // StorageUserBase
68 //===----------------------------------------------------------------------===//
69
70 namespace storage_user_base_impl {
71 /// Returns true if this given Trait ID matches the IDs of any of the provided
72 /// trait types `Traits`.
73 template <template <typename T> class... Traits>
hasTrait(TypeID traitID)74 bool hasTrait(TypeID traitID) {
75 TypeID traitIDs[] = {TypeID::get<Traits>()...};
76 for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
77 if (traitIDs[i] == traitID)
78 return true;
79 return false;
80 }
81
82 // We specialize for the empty case to not define an empty array.
83 template <>
hasTrait(TypeID traitID)84 inline bool hasTrait(TypeID traitID) {
85 return false;
86 }
87 } // namespace storage_user_base_impl
88
89 /// Utility class for implementing users of storage classes uniqued by a
90 /// StorageUniquer. Clients are not expected to interact with this class
91 /// directly.
92 template <typename ConcreteT, typename BaseT, typename StorageT,
93 typename UniquerT, template <typename T> class... Traits>
94 class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
95 public:
96 using BaseT::BaseT;
97
98 /// Utility declarations for the concrete attribute class.
99 using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT, Traits...>;
100 using ImplType = StorageT;
101 using HasTraitFn = bool (*)(TypeID);
102
103 /// Return a unique identifier for the concrete type.
getTypeID()104 static TypeID getTypeID() { return TypeID::get<ConcreteT>(); }
105
106 /// Provide an implementation of 'classof' that compares the type id of the
107 /// provided value with that of the concrete type.
108 template <typename T>
classof(T val)109 static bool classof(T val) {
110 static_assert(std::is_convertible<ConcreteT, T>::value,
111 "casting from a non-convertible type");
112 return val.getTypeID() == getTypeID();
113 }
114
115 /// Returns an interface map for the interfaces registered to this storage
116 /// user. This should not be used directly.
getInterfaceMap()117 static detail::InterfaceMap getInterfaceMap() {
118 return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
119 }
120
121 /// Returns the function that returns true if the given Trait ID matches the
122 /// IDs of any of the traits defined by the storage user.
getHasTraitFn()123 static HasTraitFn getHasTraitFn() {
124 return [](TypeID id) {
125 return storage_user_base_impl::hasTrait<Traits...>(id);
126 };
127 }
128
129 /// Attach the given models as implementations of the corresponding interfaces
130 /// for the concrete storage user class. The type must be registered with the
131 /// context, i.e. the dialect to which the type belongs must be loaded. The
132 /// call will abort otherwise.
133 template <typename... IfaceModels>
attachInterface(MLIRContext & context)134 static void attachInterface(MLIRContext &context) {
135 typename ConcreteT::AbstractTy *abstract =
136 ConcreteT::AbstractTy::lookupMutable(TypeID::get<ConcreteT>(),
137 &context);
138 if (!abstract)
139 llvm::report_fatal_error("Registering an interface for an attribute/type "
140 "that is not itself registered.");
141 (void)std::initializer_list<int>{
142 (checkInterfaceTarget<IfaceModels>(), 0)...};
143 abstract->interfaceMap.template insert<IfaceModels...>();
144 }
145
146 /// Get or create a new ConcreteT instance within the ctx. This
147 /// function is guaranteed to return a non null object and will assert if
148 /// the arguments provided are invalid.
149 template <typename... Args>
get(MLIRContext * ctx,Args...args)150 static ConcreteT get(MLIRContext *ctx, Args... args) {
151 // Ensure that the invariants are correct for construction.
152 assert(
153 succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
154 return UniquerT::template get<ConcreteT>(ctx, args...);
155 }
156
157 /// Get or create a new ConcreteT instance within the ctx, defined at
158 /// the given, potentially unknown, location. If the arguments provided are
159 /// invalid, errors are emitted using the provided location and a null object
160 /// is returned.
161 template <typename... Args>
getChecked(const Location & loc,Args...args)162 static ConcreteT getChecked(const Location &loc, Args... args) {
163 return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), args...);
164 }
165
166 /// Get or create a new ConcreteT instance within the ctx. If the arguments
167 /// provided are invalid, errors are emitted using the provided `emitError`
168 /// and a null object is returned.
169 template <typename... Args>
getChecked(function_ref<InFlightDiagnostic ()> emitErrorFn,MLIRContext * ctx,Args...args)170 static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
171 MLIRContext *ctx, Args... args) {
172 // If the construction invariants fail then we return a null attribute.
173 if (failed(ConcreteT::verify(emitErrorFn, args...)))
174 return ConcreteT();
175 return UniquerT::template get<ConcreteT>(ctx, args...);
176 }
177
178 /// Get an instance of the concrete type from a void pointer.
getFromOpaquePointer(const void * ptr)179 static ConcreteT getFromOpaquePointer(const void *ptr) {
180 return ConcreteT((const typename BaseT::ImplType *)ptr);
181 }
182
183 protected:
184 /// Mutate the current storage instance. This will not change the unique key.
185 /// The arguments are forwarded to 'ConcreteT::mutate'.
186 template <typename... Args>
mutate(Args &&...args)187 LogicalResult mutate(Args &&...args) {
188 static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>,
189 ConcreteT>::value,
190 "The `mutate` function expects mutable trait "
191 "(e.g. TypeTrait::IsMutable) to be attached on parent.");
192 return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
193 std::forward<Args>(args)...);
194 }
195
196 /// Default implementation that just returns success.
197 template <typename... Args>
verify(Args...args)198 static LogicalResult verify(Args... args) {
199 return success();
200 }
201
202 /// Utility for easy access to the storage instance.
getImpl()203 ImplType *getImpl() const { return static_cast<ImplType *>(this->impl); }
204
205 private:
206 /// Trait to check if T provides a 'ConcreteEntity' type alias.
207 template <typename T>
208 using has_concrete_entity_t = typename T::ConcreteEntity;
209
210 /// A struct-wrapped type alias to T::ConcreteEntity if provided and to
211 /// ConcreteT otherwise. This is akin to std::conditional but doesn't fail on
212 /// the missing typedef. Useful for checking if the interface is targeting the
213 /// right class.
214 template <typename T,
215 bool = llvm::is_detected<has_concrete_entity_t, T>::value>
216 struct IfaceTargetOrConcreteT {
217 using type = typename T::ConcreteEntity;
218 };
219 template <typename T>
220 struct IfaceTargetOrConcreteT<T, false> {
221 using type = ConcreteT;
222 };
223
224 /// A hook for static assertion that the external interface model T is
225 /// targeting a base class of the concrete attribute/type. The model can also
226 /// be a fallback model that works for every attribute/type.
227 template <typename T>
228 static void checkInterfaceTarget() {
229 static_assert(std::is_base_of<typename IfaceTargetOrConcreteT<T>::type,
230 ConcreteT>::value,
231 "attaching an interface to the wrong attribute/type kind");
232 }
233 };
234 } // namespace detail
235 } // namespace mlir
236
237 #endif
238