xref: /llvm-project-15.0.7/mlir/lib/IR/Dialect.cpp (revision d3dfd8ce)
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/DialectHooks.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "mlir/IR/Operation.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Support/ManagedStatic.h"
19 #include "llvm/Support/Regex.h"
20 
21 using namespace mlir;
22 using namespace detail;
23 
24 DialectAsmParser::~DialectAsmParser() {}
25 
26 //===----------------------------------------------------------------------===//
27 // Dialect Registration
28 //===----------------------------------------------------------------------===//
29 
/// Registry for all dialect allocation functions, keyed by the TypeID supplied
/// to Dialect::registerDialectAllocator. registerAllDialects walks this
/// registry and invokes every stored function with the target MLIRContext.
static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectAllocatorFunction>>
    dialectRegistry;
33 
/// Registry for functions that set dialect hooks, keyed by the TypeID supplied
/// to DialectHooks::registerDialectHooksSetter. registerAllDialects walks this
/// registry and invokes every stored function with the target MLIRContext.
static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectHooksSetter>>
    dialectHooksRegistry;
37 
38 void Dialect::registerDialectAllocator(
39     TypeID typeID, const DialectAllocatorFunction &function) {
40   assert(function &&
41          "Attempting to register an empty dialect initialize function");
42   dialectRegistry->insert({typeID, function});
43 }
44 
45 /// Registers a function to set specific hooks for a specific dialect, typically
46 /// used through the DialectHooksRegistration template.
47 void DialectHooks::registerDialectHooksSetter(
48     TypeID typeID, const DialectHooksSetter &function) {
49   assert(
50       function &&
51       "Attempting to register an empty dialect hooks initialization function");
52 
53   dialectHooksRegistry->insert({typeID, function});
54 }
55 
56 /// Registers all dialects and hooks from the global registries with the
57 /// specified MLIRContext.
58 void mlir::registerAllDialects(MLIRContext *context) {
59   for (const auto &it : *dialectRegistry)
60     it.second(context);
61   for (const auto &it : *dialectHooksRegistry)
62     it.second(context);
63 }
64 
65 //===----------------------------------------------------------------------===//
66 // Dialect
67 //===----------------------------------------------------------------------===//
68 
69 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
70     : name(name), dialectID(id), context(context) {
71   assert(isValidNamespace(name) && "invalid dialect namespace");
72 }
73 
74 Dialect::~Dialect() {}
75 
76 /// Verify an attribute from this dialect on the argument at 'argIndex' for
77 /// the region at 'regionIndex' on the given operation. Returns failure if
78 /// the verification failed, success otherwise. This hook may optionally be
79 /// invoked from any operation containing a region.
80 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
81                                                 NamedAttribute) {
82   return success();
83 }
84 
85 /// Verify an attribute from this dialect on the result at 'resultIndex' for
86 /// the region at 'regionIndex' on the given operation. Returns failure if
87 /// the verification failed, success otherwise. This hook may optionally be
88 /// invoked from any operation containing a region.
89 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
90                                                    unsigned, NamedAttribute) {
91   return success();
92 }
93 
94 /// Parse an attribute registered to this dialect.
95 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
96   parser.emitError(parser.getNameLoc())
97       << "dialect '" << getNamespace()
98       << "' provides no attribute parsing hook";
99   return Attribute();
100 }
101 
102 /// Parse a type registered to this dialect.
103 Type Dialect::parseType(DialectAsmParser &parser) const {
104   // If this dialect allows unknown types, then represent this with OpaqueType.
105   if (allowsUnknownTypes()) {
106     auto ns = Identifier::get(getNamespace(), getContext());
107     return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
108   }
109 
110   parser.emitError(parser.getNameLoc())
111       << "dialect '" << getNamespace() << "' provides no type parsing hook";
112   return Type();
113 }
114 
115 /// Utility function that returns if the given string is a valid dialect
116 /// namespace.
117 bool Dialect::isValidNamespace(StringRef str) {
118   if (str.empty())
119     return true;
120   llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
121   return dialectNameRegex.match(str);
122 }
123 
124 /// Register a set of dialect interfaces with this dialect instance.
125 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
126   auto it = registeredInterfaces.try_emplace(interface->getID(),
127                                              std::move(interface));
128   (void)it;
129   assert(it.second && "interface kind has already been registered");
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // Dialect Interface
134 //===----------------------------------------------------------------------===//
135 
136 DialectInterface::~DialectInterface() {}
137 
138 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
139     MLIRContext *ctx, TypeID interfaceKind) {
140   for (auto *dialect : ctx->getRegisteredDialects()) {
141     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
142       interfaces.insert(interface);
143       orderedInterfaces.push_back(interface);
144     }
145   }
146 }
147 
148 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
149 
150 /// Get the interface for the dialect of given operation, or null if one
151 /// is not registered.
152 const DialectInterface *
153 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
154   return getInterfaceFor(op->getDialect());
155 }
156