xref: /llvm-project-15.0.7/mlir/lib/IR/Dialect.cpp (revision 0570de73)
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/DialectHooks.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "mlir/IR/Operation.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Support/ManagedStatic.h"
19 #include "llvm/Support/Regex.h"
20 
21 using namespace mlir;
22 using namespace detail;
23 
24 DialectAsmParser::~DialectAsmParser() {}
25 
26 //===----------------------------------------------------------------------===//
27 // Dialect Registration
28 //===----------------------------------------------------------------------===//
29 
30 /// Registry for all dialect allocation functions.
31 static llvm::ManagedStatic<
32     llvm::MapVector<const ClassID *, DialectAllocatorFunction>>
33     dialectRegistry;
34 
35 /// Registry for functions that set dialect hooks.
36 static llvm::ManagedStatic<llvm::MapVector<const ClassID *, DialectHooksSetter>>
37     dialectHooksRegistry;
38 
39 void Dialect::registerDialectAllocator(
40     const ClassID *classId, const DialectAllocatorFunction &function) {
41   assert(function &&
42          "Attempting to register an empty dialect initialize function");
43   dialectRegistry->insert({classId, function});
44 }
45 
46 /// Registers a function to set specific hooks for a specific dialect, typically
47 /// used through the DialectHooksRegistration template.
48 void DialectHooks::registerDialectHooksSetter(
49     const ClassID *classId, const DialectHooksSetter &function) {
50   assert(
51       function &&
52       "Attempting to register an empty dialect hooks initialization function");
53 
54   dialectHooksRegistry->insert({classId, function});
55 }
56 
57 /// Registers all dialects and hooks from the global registries with the
58 /// specified MLIRContext.
59 void mlir::registerAllDialects(MLIRContext *context) {
60   for (const auto &it : *dialectRegistry)
61     it.second(context);
62   for (const auto &it : *dialectHooksRegistry) {
63     it.second(context);
64   }
65 }
66 
67 //===----------------------------------------------------------------------===//
68 // Dialect
69 //===----------------------------------------------------------------------===//
70 
71 Dialect::Dialect(StringRef name, MLIRContext *context)
72     : name(name), context(context) {
73   assert(isValidNamespace(name) && "invalid dialect namespace");
74   registerDialect(context);
75 }
76 
77 Dialect::~Dialect() {}
78 
79 /// Verify an attribute from this dialect on the argument at 'argIndex' for
80 /// the region at 'regionIndex' on the given operation. Returns failure if
81 /// the verification failed, success otherwise. This hook may optionally be
82 /// invoked from any operation containing a region.
83 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
84                                                 NamedAttribute) {
85   return success();
86 }
87 
88 /// Verify an attribute from this dialect on the result at 'resultIndex' for
89 /// the region at 'regionIndex' on the given operation. Returns failure if
90 /// the verification failed, success otherwise. This hook may optionally be
91 /// invoked from any operation containing a region.
92 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
93                                                    unsigned, NamedAttribute) {
94   return success();
95 }
96 
97 /// Parse an attribute registered to this dialect.
98 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
99   parser.emitError(parser.getNameLoc())
100       << "dialect '" << getNamespace()
101       << "' provides no attribute parsing hook";
102   return Attribute();
103 }
104 
105 /// Parse a type registered to this dialect.
106 Type Dialect::parseType(DialectAsmParser &parser) const {
107   // If this dialect allows unknown types, then represent this with OpaqueType.
108   if (allowsUnknownTypes()) {
109     auto ns = Identifier::get(getNamespace(), getContext());
110     return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
111   }
112 
113   parser.emitError(parser.getNameLoc())
114       << "dialect '" << getNamespace() << "' provides no type parsing hook";
115   return Type();
116 }
117 
118 /// Utility function that returns if the given string is a valid dialect
119 /// namespace.
120 bool Dialect::isValidNamespace(StringRef str) {
121   if (str.empty())
122     return true;
123   llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
124   return dialectNameRegex.match(str);
125 }
126 
127 /// Register a set of dialect interfaces with this dialect instance.
128 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
129   auto it = registeredInterfaces.try_emplace(interface->getID(),
130                                              std::move(interface));
131   (void)it;
132   assert(it.second && "interface kind has already been registered");
133 }
134 
135 //===----------------------------------------------------------------------===//
136 // Dialect Interface
137 //===----------------------------------------------------------------------===//
138 
139 DialectInterface::~DialectInterface() {}
140 
141 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
142     MLIRContext *ctx, ClassID *interfaceKind) {
143   for (auto *dialect : ctx->getRegisteredDialects()) {
144     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
145       interfaces.insert(interface);
146       orderedInterfaces.push_back(interface);
147     }
148   }
149 }
150 
151 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
152 
153 /// Get the interface for the dialect of given operation, or null if one
154 /// is not registered.
155 const DialectInterface *
156 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
157   return getInterfaceFor(op->getDialect());
158 }
159