1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "mlir/IR/Operation.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/ManagedStatic.h"
20 #include "llvm/Support/Regex.h"
21
22 #define DEBUG_TYPE "dialect"
23
24 using namespace mlir;
25 using namespace detail;
26
27 //===----------------------------------------------------------------------===//
28 // Dialect
29 //===----------------------------------------------------------------------===//
30
/// Construct a dialect with the namespace `name`, attached to `context`, and
/// identified by the type ID `id`. The three values are stored as-is in the
/// corresponding members.
Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
    : name(name), dialectID(id), context(context) {
  // The namespace is validated by assertion only, so an invalid name is not
  // rejected when assertions are compiled out.
  assert(isValidNamespace(name) && "invalid dialect namespace");
}
35
/// Out-of-line defaulted destructor.
Dialect::~Dialect() = default;
37
38 /// Verify an attribute from this dialect on the argument at 'argIndex' for
39 /// the region at 'regionIndex' on the given operation. Returns failure if
40 /// the verification failed, success otherwise. This hook may optionally be
41 /// invoked from any operation containing a region.
verifyRegionArgAttribute(Operation *,unsigned,unsigned,NamedAttribute)42 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
43 NamedAttribute) {
44 return success();
45 }
46
47 /// Verify an attribute from this dialect on the result at 'resultIndex' for
48 /// the region at 'regionIndex' on the given operation. Returns failure if
49 /// the verification failed, success otherwise. This hook may optionally be
50 /// invoked from any operation containing a region.
verifyRegionResultAttribute(Operation *,unsigned,unsigned,NamedAttribute)51 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
52 unsigned, NamedAttribute) {
53 return success();
54 }
55
56 /// Parse an attribute registered to this dialect.
parseAttribute(DialectAsmParser & parser,Type type) const57 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
58 parser.emitError(parser.getNameLoc())
59 << "dialect '" << getNamespace()
60 << "' provides no attribute parsing hook";
61 return Attribute();
62 }
63
64 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const65 Type Dialect::parseType(DialectAsmParser &parser) const {
66 // If this dialect allows unknown types, then represent this with OpaqueType.
67 if (allowsUnknownTypes()) {
68 StringAttr ns = StringAttr::get(getContext(), getNamespace());
69 return OpaqueType::get(ns, parser.getFullSymbolSpec());
70 }
71
72 parser.emitError(parser.getNameLoc())
73 << "dialect '" << getNamespace() << "' provides no type parsing hook";
74 return Type();
75 }
76
77 Optional<Dialect::ParseOpHook>
getParseOperationHook(StringRef opName) const78 Dialect::getParseOperationHook(StringRef opName) const {
79 return None;
80 }
81
82 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
getOperationPrinter(Operation * op) const83 Dialect::getOperationPrinter(Operation *op) const {
84 assert(op->getDialect() == this &&
85 "Dialect hook invoked on non-dialect owned operation");
86 return nullptr;
87 }
88
89 /// Utility function that returns if the given string is a valid dialect
90 /// namespace
isValidNamespace(StringRef str)91 bool Dialect::isValidNamespace(StringRef str) {
92 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
93 return dialectNameRegex.match(str);
94 }
95
96 /// Register a set of dialect interfaces with this dialect instance.
addInterface(std::unique_ptr<DialectInterface> interface)97 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
98 auto it = registeredInterfaces.try_emplace(interface->getID(),
99 std::move(interface));
100 (void)it;
101 LLVM_DEBUG({
102 if (!it.second) {
103 llvm::dbgs() << "[" DEBUG_TYPE
104 "] repeated interface registration for dialect "
105 << getNamespace();
106 }
107 });
108 }
109
110 //===----------------------------------------------------------------------===//
111 // Dialect Interface
112 //===----------------------------------------------------------------------===//
113
/// Out-of-line defaulted destructor.
DialectInterface::~DialectInterface() = default;
115
DialectInterfaceCollectionBase(MLIRContext * ctx,TypeID interfaceKind)116 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
117 MLIRContext *ctx, TypeID interfaceKind) {
118 for (auto *dialect : ctx->getLoadedDialects()) {
119 if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
120 interfaces.insert(interface);
121 orderedInterfaces.push_back(interface);
122 }
123 }
124 }
125
/// Out-of-line defaulted destructor.
DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
127
128 /// Get the interface for the dialect of given operation, or null if one
129 /// is not registered.
130 const DialectInterface *
getInterfaceFor(Operation * op) const131 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
132 return getInterfaceFor(op->getDialect());
133 }
134
135 //===----------------------------------------------------------------------===//
136 // DialectExtension
137 //===----------------------------------------------------------------------===//
138
/// Out-of-line defaulted destructor.
DialectExtensionBase::~DialectExtensionBase() = default;
140
141 //===----------------------------------------------------------------------===//
142 // DialectRegistry
143 //===----------------------------------------------------------------------===//
144
/// Construct a registry that already has the builtin dialect inserted.
DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
146
147 DialectAllocatorFunctionRef
getDialectAllocator(StringRef name) const148 DialectRegistry::getDialectAllocator(StringRef name) const {
149 auto it = registry.find(name.str());
150 if (it == registry.end())
151 return nullptr;
152 return it->second.second;
153 }
154
insert(TypeID typeID,StringRef name,const DialectAllocatorFunction & ctor)155 void DialectRegistry::insert(TypeID typeID, StringRef name,
156 const DialectAllocatorFunction &ctor) {
157 auto inserted = registry.insert(
158 std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
159 if (!inserted.second && inserted.first->second.first != typeID) {
160 llvm::report_fatal_error(
161 "Trying to register different dialects for the same namespace: " +
162 name);
163 }
164 }
165
applyExtensions(Dialect * dialect) const166 void DialectRegistry::applyExtensions(Dialect *dialect) const {
167 MLIRContext *ctx = dialect->getContext();
168 StringRef dialectName = dialect->getNamespace();
169
170 // Functor used to try to apply the given extension.
171 auto applyExtension = [&](const DialectExtensionBase &extension) {
172 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
173
174 // Handle the simple case of a single dialect name. In this case, the
175 // required dialect should be the current dialect.
176 if (dialectNames.size() == 1) {
177 if (dialectNames.front() == dialectName)
178 extension.apply(ctx, dialect);
179 return;
180 }
181
182 // Otherwise, check to see if this extension requires this dialect.
183 const StringRef *nameIt = llvm::find(dialectNames, dialectName);
184 if (nameIt == dialectNames.end())
185 return;
186
187 // If it does, ensure that all of the other required dialects have been
188 // loaded.
189 SmallVector<Dialect *> requiredDialects;
190 requiredDialects.reserve(dialectNames.size());
191 for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
192 ++it) {
193 // The current dialect is known to be loaded.
194 if (it == nameIt) {
195 requiredDialects.push_back(dialect);
196 continue;
197 }
198 // Otherwise, check if it is loaded.
199 Dialect *loadedDialect = ctx->getLoadedDialect(*it);
200 if (!loadedDialect)
201 return;
202 requiredDialects.push_back(loadedDialect);
203 }
204 extension.apply(ctx, requiredDialects);
205 };
206
207 for (const auto &extension : extensions)
208 applyExtension(*extension);
209 }
210
applyExtensions(MLIRContext * ctx) const211 void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
212 // Functor used to try to apply the given extension.
213 auto applyExtension = [&](const DialectExtensionBase &extension) {
214 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
215
216 // Check to see if all of the dialects for this extension are loaded.
217 SmallVector<Dialect *> requiredDialects;
218 requiredDialects.reserve(dialectNames.size());
219 for (StringRef dialectName : dialectNames) {
220 Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
221 if (!loadedDialect)
222 return;
223 requiredDialects.push_back(loadedDialect);
224 }
225 extension.apply(ctx, requiredDialects);
226 };
227
228 for (const auto &extension : extensions)
229 applyExtension(*extension);
230 }
231
isSubsetOf(const DialectRegistry & rhs) const232 bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
233 // Treat any extensions conservatively.
234 if (!extensions.empty())
235 return false;
236 // Check that the current dialects fully overlap with the dialects in 'rhs'.
237 return llvm::all_of(
238 registry, [&](const auto &it) { return rhs.registry.count(it.first); });
239 }
240