xref: /llvm-project-15.0.7/mlir/lib/IR/Dialect.cpp (revision bb09ef95)
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "mlir/IR/DialectInterface.h"
13 #include "mlir/IR/MLIRContext.h"
14 #include "mlir/IR/Operation.h"
15 #include "llvm/ADT/MapVector.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/ManagedStatic.h"
18 #include "llvm/Support/Regex.h"
19 
20 using namespace mlir;
21 using namespace detail;
22 
23 DialectAsmParser::~DialectAsmParser() {}
24 
25 //===----------------------------------------------------------------------===//
26 // Dialect Registration (DEPRECATED)
27 //===----------------------------------------------------------------------===//
28 
29 /// Registry for all dialect allocation functions.
30 static llvm::ManagedStatic<DialectRegistry> dialectRegistry;
31 DialectRegistry &mlir::getGlobalDialectRegistry() { return *dialectRegistry; }
32 
33 // Note: deprecated, will be removed soon.
34 static bool isGlobalDialectRegistryEnabledFlag = false;
35 void mlir::enableGlobalDialectRegistry(bool enable) {
36   isGlobalDialectRegistryEnabledFlag = enable;
37 }
38 bool mlir::isGlobalDialectRegistryEnabled() {
39   return isGlobalDialectRegistryEnabledFlag;
40 }
41 
42 void mlir::registerAllDialects(MLIRContext *context) {
43   dialectRegistry->appendTo(context->getDialectRegistry());
44 }
45 
46 Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) {
47   auto it = registry.find(name.str());
48   if (it == registry.end())
49     return nullptr;
50   return it->second.second(context);
51 }
52 
53 void DialectRegistry::insert(TypeID typeID, StringRef name,
54                              DialectAllocatorFunction ctor) {
55   auto inserted = registry.insert(
56       std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
57   if (!inserted.second && inserted.first->second.first != typeID) {
58     llvm::report_fatal_error(
59         "Trying to register different dialects for the same namespace: " +
60         name);
61   }
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // Dialect
66 //===----------------------------------------------------------------------===//
67 
68 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
69     : name(name), dialectID(id), context(context) {
70   assert(isValidNamespace(name) && "invalid dialect namespace");
71 }
72 
73 Dialect::~Dialect() {}
74 
75 /// Verify an attribute from this dialect on the argument at 'argIndex' for
76 /// the region at 'regionIndex' on the given operation. Returns failure if
77 /// the verification failed, success otherwise. This hook may optionally be
78 /// invoked from any operation containing a region.
79 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
80                                                 NamedAttribute) {
81   return success();
82 }
83 
84 /// Verify an attribute from this dialect on the result at 'resultIndex' for
85 /// the region at 'regionIndex' on the given operation. Returns failure if
86 /// the verification failed, success otherwise. This hook may optionally be
87 /// invoked from any operation containing a region.
88 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
89                                                    unsigned, NamedAttribute) {
90   return success();
91 }
92 
93 /// Parse an attribute registered to this dialect.
94 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
95   parser.emitError(parser.getNameLoc())
96       << "dialect '" << getNamespace()
97       << "' provides no attribute parsing hook";
98   return Attribute();
99 }
100 
101 /// Parse a type registered to this dialect.
102 Type Dialect::parseType(DialectAsmParser &parser) const {
103   // If this dialect allows unknown types, then represent this with OpaqueType.
104   if (allowsUnknownTypes()) {
105     auto ns = Identifier::get(getNamespace(), getContext());
106     return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
107   }
108 
109   parser.emitError(parser.getNameLoc())
110       << "dialect '" << getNamespace() << "' provides no type parsing hook";
111   return Type();
112 }
113 
114 /// Utility function that returns if the given string is a valid dialect
115 /// namespace.
116 bool Dialect::isValidNamespace(StringRef str) {
117   if (str.empty())
118     return true;
119   llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
120   return dialectNameRegex.match(str);
121 }
122 
123 /// Register a set of dialect interfaces with this dialect instance.
124 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
125   auto it = registeredInterfaces.try_emplace(interface->getID(),
126                                              std::move(interface));
127   (void)it;
128   assert(it.second && "interface kind has already been registered");
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // Dialect Interface
133 //===----------------------------------------------------------------------===//
134 
135 DialectInterface::~DialectInterface() {}
136 
137 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
138     MLIRContext *ctx, TypeID interfaceKind) {
139   for (auto *dialect : ctx->getLoadedDialects()) {
140     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
141       interfaces.insert(interface);
142       orderedInterfaces.push_back(interface);
143     }
144   }
145 }
146 
147 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
148 
149 /// Get the interface for the dialect of given operation, or null if one
150 /// is not registered.
151 const DialectInterface *
152 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
153   return getInterfaceFor(op->getDialect());
154 }
155