1 //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_DIALECTINTERFACE_H
10 #define MLIR_IR_DIALECTINTERFACE_H
11 
12 #include "mlir/Support/TypeID.h"
13 #include "llvm/ADT/DenseSet.h"
14 #include "llvm/ADT/STLExtras.h"
15 
16 namespace mlir {
17 class Dialect;
18 class MLIRContext;
19 class Operation;
20 
21 //===----------------------------------------------------------------------===//
22 // DialectInterface
23 //===----------------------------------------------------------------------===//
24 namespace detail {
25 /// The base class used for all derived interface types. This class provides
26 /// utilities necessary for registration.
27 template <typename ConcreteType, typename BaseT>
28 class DialectInterfaceBase : public BaseT {
29 public:
30   using Base = DialectInterfaceBase<ConcreteType, BaseT>;
31 
32   /// Get a unique id for the derived interface type.
getInterfaceID()33   static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
34 
35 protected:
DialectInterfaceBase(Dialect * dialect)36   DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {}
37 };
38 } // namespace detail
39 
40 /// This class represents an interface overridden for a single dialect.
41 class DialectInterface {
42 public:
43   virtual ~DialectInterface();
44 
45   /// The base class used for all derived interface types. This class provides
46   /// utilities necessary for registration.
47   template <typename ConcreteType>
48   using Base = detail::DialectInterfaceBase<ConcreteType, DialectInterface>;
49 
50   /// Return the dialect that this interface represents.
getDialect()51   Dialect *getDialect() const { return dialect; }
52 
53   /// Return the derived interface id.
getID()54   TypeID getID() const { return interfaceID; }
55 
56 protected:
DialectInterface(Dialect * dialect,TypeID id)57   DialectInterface(Dialect *dialect, TypeID id)
58       : dialect(dialect), interfaceID(id) {}
59 
60 private:
61   /// The dialect that represents this interface.
62   Dialect *dialect;
63 
64   /// The unique identifier for the derived interface type.
65   TypeID interfaceID;
66 };
67 
68 //===----------------------------------------------------------------------===//
69 // DialectInterfaceCollection
70 //===----------------------------------------------------------------------===//
71 
72 namespace detail {
73 /// This class is the base class for a collection of instances for a specific
74 /// interface kind.
75 class DialectInterfaceCollectionBase {
76   /// DenseMap info for dialect interfaces that allows lookup by the dialect.
77   struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> {
78     using DenseMapInfo<const DialectInterface *>::isEqual;
79 
getHashValueInterfaceKeyInfo80     static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); }
getHashValueInterfaceKeyInfo81     static unsigned getHashValue(const DialectInterface *key) {
82       return getHashValue(key->getDialect());
83     }
84 
isEqualInterfaceKeyInfo85     static bool isEqual(Dialect *lhs, const DialectInterface *rhs) {
86       if (rhs == getEmptyKey() || rhs == getTombstoneKey())
87         return false;
88       return lhs == rhs->getDialect();
89     }
90   };
91 
92   /// A set of registered dialect interface instances.
93   using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>;
94   using InterfaceVectorT = std::vector<const DialectInterface *>;
95 
96 public:
97   DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind);
98   virtual ~DialectInterfaceCollectionBase();
99 
100 protected:
101   /// Get the interface for the dialect of given operation, or null if one
102   /// is not registered.
103   const DialectInterface *getInterfaceFor(Operation *op) const;
104 
105   /// Get the interface for the given dialect.
getInterfaceFor(Dialect * dialect)106   const DialectInterface *getInterfaceFor(Dialect *dialect) const {
107     auto it = interfaces.find_as(dialect);
108     return it == interfaces.end() ? nullptr : *it;
109   }
110 
111   /// An iterator class that iterates the held interface objects of the given
112   /// derived interface type.
113   template <typename InterfaceT>
114   struct iterator
115       : public llvm::mapped_iterator_base<iterator<InterfaceT>,
116                                           InterfaceVectorT::const_iterator,
117                                           const InterfaceT &> {
118     using llvm::mapped_iterator_base<iterator<InterfaceT>,
119                                      InterfaceVectorT::const_iterator,
120                                      const InterfaceT &>::mapped_iterator_base;
121 
122     /// Map the element to the iterator result type.
mapElementiterator123     const InterfaceT &mapElement(const DialectInterface *interface) const {
124       return *static_cast<const InterfaceT *>(interface);
125     }
126   };
127 
128   /// Iterator access to the held interfaces.
129   template <typename InterfaceT>
interface_begin()130   iterator<InterfaceT> interface_begin() const {
131     return iterator<InterfaceT>(orderedInterfaces.begin());
132   }
133   template <typename InterfaceT>
interface_end()134   iterator<InterfaceT> interface_end() const {
135     return iterator<InterfaceT>(orderedInterfaces.end());
136   }
137 
138 private:
139   /// A set of registered dialect interface instances.
140   InterfaceSetT interfaces;
141   /// An ordered list of the registered interface instances, necessary for
142   /// deterministic iteration.
143   // NOTE: SetVector does not provide find access, so it can't be used here.
144   InterfaceVectorT orderedInterfaces;
145 };
146 } // namespace detail
147 
148 /// A collection of dialect interfaces within a context, for a given concrete
149 /// interface type.
150 template <typename InterfaceType>
151 class DialectInterfaceCollection
152     : public detail::DialectInterfaceCollectionBase {
153 public:
154   using Base = DialectInterfaceCollection<InterfaceType>;
155 
156   /// Collect the registered dialect interfaces within the provided context.
DialectInterfaceCollection(MLIRContext * ctx)157   DialectInterfaceCollection(MLIRContext *ctx)
158       : detail::DialectInterfaceCollectionBase(
159             ctx, InterfaceType::getInterfaceID()) {}
160 
161   /// Get the interface for a given object, or null if one is not registered.
162   /// The object may be a dialect or an operation instance.
163   template <typename Object>
getInterfaceFor(Object * obj)164   const InterfaceType *getInterfaceFor(Object *obj) const {
165     return static_cast<const InterfaceType *>(
166         detail::DialectInterfaceCollectionBase::getInterfaceFor(obj));
167   }
168 
169   /// Iterator access to the held interfaces.
170   using iterator =
171       detail::DialectInterfaceCollectionBase::iterator<InterfaceType>;
begin()172   iterator begin() const { return interface_begin<InterfaceType>(); }
end()173   iterator end() const { return interface_end<InterfaceType>(); }
174 
175 private:
176   using detail::DialectInterfaceCollectionBase::interface_begin;
177   using detail::DialectInterfaceCollectionBase::interface_end;
178 };
179 
180 } // namespace mlir
181 
182 #endif
183