xref: /llvm-project-15.0.7/mlir/lib/IR/Dialect.cpp (revision cbc9d22e)
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/DialectHooks.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "mlir/IR/Operation.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/ManagedStatic.h"
18 #include "llvm/Support/Regex.h"
19 
20 using namespace mlir;
21 using namespace detail;
22 
23 DialectAsmParser::~DialectAsmParser() {}
24 
25 //===----------------------------------------------------------------------===//
26 // Dialect Registration
27 //===----------------------------------------------------------------------===//
28 
29 // Registry for all dialect allocation functions.
30 static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>>
31     dialectRegistry;
32 
33 // Registry for functions that set dialect hooks.
34 static llvm::ManagedStatic<SmallVector<DialectHooksSetter, 8>>
35     dialectHooksRegistry;
36 
37 /// Registers a specific dialect creation function with the system, typically
38 /// used through the DialectRegistration template.
39 void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) {
40   assert(function &&
41          "Attempting to register an empty dialect initialize function");
42   dialectRegistry->push_back(function);
43 }
44 
45 /// Registers a function to set specific hooks for a specific dialect, typically
46 /// used through the DialectHooksRegistration template.
47 void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) {
48   assert(
49       function &&
50       "Attempting to register an empty dialect hooks initialization function");
51 
52   dialectHooksRegistry->push_back(function);
53 }
54 
55 /// Registers all dialects and their const folding hooks with the specified
56 /// MLIRContext.
57 void mlir::registerAllDialects(MLIRContext *context) {
58   for (const auto &fn : *dialectRegistry)
59     fn(context);
60   for (const auto &fn : *dialectHooksRegistry) {
61     fn(context);
62   }
63 }
64 
65 //===----------------------------------------------------------------------===//
66 // Dialect
67 //===----------------------------------------------------------------------===//
68 
69 Dialect::Dialect(StringRef name, MLIRContext *context)
70     : name(name), context(context) {
71   assert(isValidNamespace(name) && "invalid dialect namespace");
72   registerDialect(context);
73 }
74 
75 Dialect::~Dialect() {}
76 
77 /// Verify an attribute from this dialect on the argument at 'argIndex' for
78 /// the region at 'regionIndex' on the given operation. Returns failure if
79 /// the verification failed, success otherwise. This hook may optionally be
80 /// invoked from any operation containing a region.
81 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
82                                                 NamedAttribute) {
83   return success();
84 }
85 
86 /// Verify an attribute from this dialect on the result at 'resultIndex' for
87 /// the region at 'regionIndex' on the given operation. Returns failure if
88 /// the verification failed, success otherwise. This hook may optionally be
89 /// invoked from any operation containing a region.
90 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
91                                                    unsigned, NamedAttribute) {
92   return success();
93 }
94 
95 /// Parse an attribute registered to this dialect.
96 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
97   parser.emitError(parser.getNameLoc())
98       << "dialect '" << getNamespace()
99       << "' provides no attribute parsing hook";
100   return Attribute();
101 }
102 
103 /// Parse a type registered to this dialect.
104 Type Dialect::parseType(DialectAsmParser &parser) const {
105   // If this dialect allows unknown types, then represent this with OpaqueType.
106   if (allowsUnknownTypes()) {
107     auto ns = Identifier::get(getNamespace(), getContext());
108     return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
109   }
110 
111   parser.emitError(parser.getNameLoc())
112       << "dialect '" << getNamespace() << "' provides no type parsing hook";
113   return Type();
114 }
115 
116 /// Utility function that returns if the given string is a valid dialect
117 /// namespace.
118 bool Dialect::isValidNamespace(StringRef str) {
119   if (str.empty())
120     return true;
121   llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
122   return dialectNameRegex.match(str);
123 }
124 
125 /// Register a set of dialect interfaces with this dialect instance.
126 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
127   auto it = registeredInterfaces.try_emplace(interface->getID(),
128                                              std::move(interface));
129   (void)it;
130   assert(it.second && "interface kind has already been registered");
131 }
132 
133 //===----------------------------------------------------------------------===//
134 // Dialect Interface
135 //===----------------------------------------------------------------------===//
136 
137 DialectInterface::~DialectInterface() {}
138 
139 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
140     MLIRContext *ctx, ClassID *interfaceKind) {
141   for (auto *dialect : ctx->getRegisteredDialects()) {
142     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
143       interfaces.insert(interface);
144       orderedInterfaces.push_back(interface);
145     }
146   }
147 }
148 
149 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
150 
151 /// Get the interface for the dialect of given operation, or null if one
152 /// is not registered.
153 const DialectInterface *
154 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
155   return getInterfaceFor(op->getDialect());
156 }
157