1 //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_DIALECTINTERFACE_H 10 #define MLIR_IR_DIALECTINTERFACE_H 11 12 #include "mlir/Support/TypeID.h" 13 #include "llvm/ADT/DenseSet.h" 14 #include "llvm/ADT/STLExtras.h" 15 16 namespace mlir { 17 class Dialect; 18 class MLIRContext; 19 class Operation; 20 21 //===----------------------------------------------------------------------===// 22 // DialectInterface 23 //===----------------------------------------------------------------------===// 24 namespace detail { 25 /// The base class used for all derived interface types. This class provides 26 /// utilities necessary for registration. 27 template <typename ConcreteType, typename BaseT> 28 class DialectInterfaceBase : public BaseT { 29 public: 30 using Base = DialectInterfaceBase<ConcreteType, BaseT>; 31 32 /// Get a unique id for the derived interface type. getInterfaceID()33 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } 34 35 protected: DialectInterfaceBase(Dialect * dialect)36 DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {} 37 }; 38 } // namespace detail 39 40 /// This class represents an interface overridden for a single dialect. 41 class DialectInterface { 42 public: 43 virtual ~DialectInterface(); 44 45 /// The base class used for all derived interface types. This class provides 46 /// utilities necessary for registration. 47 template <typename ConcreteType> 48 using Base = detail::DialectInterfaceBase<ConcreteType, DialectInterface>; 49 50 /// Return the dialect that this interface represents. getDialect()51 Dialect *getDialect() const { return dialect; } 52 53 /// Return the derived interface id. getID()54 TypeID getID() const { return interfaceID; } 55 56 protected: DialectInterface(Dialect * dialect,TypeID id)57 DialectInterface(Dialect *dialect, TypeID id) 58 : dialect(dialect), interfaceID(id) {} 59 60 private: 61 /// The dialect that represents this interface. 62 Dialect *dialect; 63 64 /// The unique identifier for the derived interface type. 65 TypeID interfaceID; 66 }; 67 68 //===----------------------------------------------------------------------===// 69 // DialectInterfaceCollection 70 //===----------------------------------------------------------------------===// 71 72 namespace detail { 73 /// This class is the base class for a collection of instances for a specific 74 /// interface kind. 75 class DialectInterfaceCollectionBase { 76 /// DenseMap info for dialect interfaces that allows lookup by the dialect. 77 struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> { 78 using DenseMapInfo<const DialectInterface *>::isEqual; 79 getHashValueInterfaceKeyInfo80 static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); } getHashValueInterfaceKeyInfo81 static unsigned getHashValue(const DialectInterface *key) { 82 return getHashValue(key->getDialect()); 83 } 84 isEqualInterfaceKeyInfo85 static bool isEqual(Dialect *lhs, const DialectInterface *rhs) { 86 if (rhs == getEmptyKey() || rhs == getTombstoneKey()) 87 return false; 88 return lhs == rhs->getDialect(); 89 } 90 }; 91 92 /// A set of registered dialect interface instances. 93 using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>; 94 using InterfaceVectorT = std::vector<const DialectInterface *>; 95 96 public: 97 DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind); 98 virtual ~DialectInterfaceCollectionBase(); 99 100 protected: 101 /// Get the interface for the dialect of given operation, or null if one 102 /// is not registered. 103 const DialectInterface *getInterfaceFor(Operation *op) const; 104 105 /// Get the interface for the given dialect. getInterfaceFor(Dialect * dialect)106 const DialectInterface *getInterfaceFor(Dialect *dialect) const { 107 auto it = interfaces.find_as(dialect); 108 return it == interfaces.end() ? nullptr : *it; 109 } 110 111 /// An iterator class that iterates the held interface objects of the given 112 /// derived interface type. 113 template <typename InterfaceT> 114 struct iterator 115 : public llvm::mapped_iterator_base<iterator<InterfaceT>, 116 InterfaceVectorT::const_iterator, 117 const InterfaceT &> { 118 using llvm::mapped_iterator_base<iterator<InterfaceT>, 119 InterfaceVectorT::const_iterator, 120 const InterfaceT &>::mapped_iterator_base; 121 122 /// Map the element to the iterator result type. mapElementiterator123 const InterfaceT &mapElement(const DialectInterface *interface) const { 124 return *static_cast<const InterfaceT *>(interface); 125 } 126 }; 127 128 /// Iterator access to the held interfaces. 129 template <typename InterfaceT> interface_begin()130 iterator<InterfaceT> interface_begin() const { 131 return iterator<InterfaceT>(orderedInterfaces.begin()); 132 } 133 template <typename InterfaceT> interface_end()134 iterator<InterfaceT> interface_end() const { 135 return iterator<InterfaceT>(orderedInterfaces.end()); 136 } 137 138 private: 139 /// A set of registered dialect interface instances. 140 InterfaceSetT interfaces; 141 /// An ordered list of the registered interface instances, necessary for 142 /// deterministic iteration. 143 // NOTE: SetVector does not provide find access, so it can't be used here. 144 InterfaceVectorT orderedInterfaces; 145 }; 146 } // namespace detail 147 148 /// A collection of dialect interfaces within a context, for a given concrete 149 /// interface type. 150 template <typename InterfaceType> 151 class DialectInterfaceCollection 152 : public detail::DialectInterfaceCollectionBase { 153 public: 154 using Base = DialectInterfaceCollection<InterfaceType>; 155 156 /// Collect the registered dialect interfaces within the provided context. DialectInterfaceCollection(MLIRContext * ctx)157 DialectInterfaceCollection(MLIRContext *ctx) 158 : detail::DialectInterfaceCollectionBase( 159 ctx, InterfaceType::getInterfaceID()) {} 160 161 /// Get the interface for a given object, or null if one is not registered. 162 /// The object may be a dialect or an operation instance. 163 template <typename Object> getInterfaceFor(Object * obj)164 const InterfaceType *getInterfaceFor(Object *obj) const { 165 return static_cast<const InterfaceType *>( 166 detail::DialectInterfaceCollectionBase::getInterfaceFor(obj)); 167 } 168 169 /// Iterator access to the held interfaces. 170 using iterator = 171 detail::DialectInterfaceCollectionBase::iterator<InterfaceType>; begin()172 iterator begin() const { return interface_begin<InterfaceType>(); } end()173 iterator end() const { return interface_end<InterfaceType>(); } 174 175 private: 176 using detail::DialectInterfaceCollectionBase::interface_begin; 177 using detail::DialectInterfaceCollectionBase::interface_end; 178 }; 179 180 } // namespace mlir 181 182 #endif 183