1 //===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/ExtensibleDialect.h"
10 #include "mlir/IR/AttributeSupport.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "mlir/IR/OperationSupport.h"
13 #include "mlir/IR/StorageUniquerSupport.h"
14 #include "mlir/Support/LogicalResult.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // Dynamic types and attributes shared functions
20 //===----------------------------------------------------------------------===//
21 
22 /// Default parser for dynamic attribute or type parameters.
23 /// Parse in the format '(<>)?' or '<attr (,attr)*>'.
24 static LogicalResult
25 typeOrAttrParser(AsmParser &parser, SmallVectorImpl<Attribute> &parsedParams) {
26   // No parameters
27   if (parser.parseOptionalLess() || !parser.parseOptionalGreater())
28     return success();
29 
30   Attribute attr;
31   if (parser.parseAttribute(attr))
32     return failure();
33   parsedParams.push_back(attr);
34 
35   while (parser.parseOptionalGreater()) {
36     Attribute attr;
37     if (parser.parseComma() || parser.parseAttribute(attr))
38       return failure();
39     parsedParams.push_back(attr);
40   }
41 
42   return success();
43 }
44 
45 /// Default printer for dynamic attribute or type parameters.
46 /// Print in the format '(<>)?' or '<attr (,attr)*>'.
47 static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef<Attribute> params) {
48   if (params.empty())
49     return;
50 
51   printer << "<";
52   interleaveComma(params, printer.getStream());
53   printer << ">";
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // Dynamic type
58 //===----------------------------------------------------------------------===//
59 
60 std::unique_ptr<DynamicTypeDefinition>
61 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
62                            VerifierFn &&verifier) {
63   return DynamicTypeDefinition::get(name, dialect, std::move(verifier),
64                                     typeOrAttrParser, typeOrAttrPrinter);
65 }
66 
67 std::unique_ptr<DynamicTypeDefinition>
68 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
69                            VerifierFn &&verifier, ParserFn &&parser,
70                            PrinterFn &&printer) {
71   return std::unique_ptr<DynamicTypeDefinition>(
72       new DynamicTypeDefinition(name, dialect, std::move(verifier),
73                                 std::move(parser), std::move(printer)));
74 }
75 
76 DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef,
77                                              ExtensibleDialect *dialect,
78                                              VerifierFn &&verifier,
79                                              ParserFn &&parser,
80                                              PrinterFn &&printer)
81     : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
82       parser(std::move(parser)), printer(std::move(printer)),
83       ctx(dialect->getContext()) {
84   assert(!nameRef.contains('.') &&
85          "name should not be prefixed by the dialect name");
86 }
87 
88 DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect *dialect,
89                                              StringRef nameRef)
90     : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {
91   assert(!nameRef.contains('.') &&
92          "name should not be prefixed by the dialect name");
93 }
94 
95 void DynamicTypeDefinition::registerInTypeUniquer() {
96   detail::TypeUniquer::registerType<DynamicType>(&getContext(), getTypeID());
97 }
98 
99 namespace mlir {
100 namespace detail {
101 /// Storage of DynamicType.
102 /// Contains a pointer to the type definition and type parameters.
103 struct DynamicTypeStorage : public TypeStorage {
104 
105   using KeyTy = std::pair<DynamicTypeDefinition *, ArrayRef<Attribute>>;
106 
107   explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef,
108                               ArrayRef<Attribute> params)
109       : typeDef(typeDef), params(params) {}
110 
111   bool operator==(const KeyTy &key) const {
112     return typeDef == key.first && params == key.second;
113   }
114 
115   static llvm::hash_code hashKey(const KeyTy &key) {
116     return llvm::hash_value(key);
117   }
118 
119   static DynamicTypeStorage *construct(TypeStorageAllocator &alloc,
120                                        const KeyTy &key) {
121     return new (alloc.allocate<DynamicTypeStorage>())
122         DynamicTypeStorage(key.first, alloc.copyInto(key.second));
123   }
124 
125   /// Definition of the type.
126   DynamicTypeDefinition *typeDef;
127 
128   /// The type parameters.
129   ArrayRef<Attribute> params;
130 };
131 } // namespace detail
132 } // namespace mlir
133 
134 DynamicType DynamicType::get(DynamicTypeDefinition *typeDef,
135                              ArrayRef<Attribute> params) {
136   auto &ctx = typeDef->getContext();
137   auto emitError = detail::getDefaultDiagnosticEmitFn(&ctx);
138   assert(succeeded(typeDef->verify(emitError, params)));
139   return detail::TypeUniquer::getWithTypeID<DynamicType>(
140       &ctx, typeDef->getTypeID(), typeDef, params);
141 }
142 
143 DynamicType
144 DynamicType::getChecked(function_ref<InFlightDiagnostic()> emitError,
145                         DynamicTypeDefinition *typeDef,
146                         ArrayRef<Attribute> params) {
147   if (failed(typeDef->verify(emitError, params)))
148     return {};
149   auto &ctx = typeDef->getContext();
150   return detail::TypeUniquer::getWithTypeID<DynamicType>(
151       &ctx, typeDef->getTypeID(), typeDef, params);
152 }
153 
154 DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; }
155 
156 ArrayRef<Attribute> DynamicType::getParams() { return getImpl()->params; }
157 
158 bool DynamicType::classof(Type type) {
159   return type.hasTrait<IsDynamicTypeTrait>();
160 }
161 
162 ParseResult DynamicType::parse(AsmParser &parser,
163                                DynamicTypeDefinition *typeDef,
164                                DynamicType &parsedType) {
165   SmallVector<Attribute> params;
166   if (failed(typeDef->parser(parser, params)))
167     return failure();
168   parsedType = parser.getChecked<DynamicType>(typeDef, params);
169   if (!parsedType)
170     return failure();
171   return success();
172 }
173 
174 void DynamicType::print(AsmPrinter &printer) {
175   printer << getTypeDef()->getName();
176   getTypeDef()->printer(printer, getParams());
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // Dynamic attribute
181 //===----------------------------------------------------------------------===//
182 
183 std::unique_ptr<DynamicAttrDefinition>
184 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
185                            VerifierFn &&verifier) {
186   return DynamicAttrDefinition::get(name, dialect, std::move(verifier),
187                                     typeOrAttrParser, typeOrAttrPrinter);
188 }
189 
190 std::unique_ptr<DynamicAttrDefinition>
191 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
192                            VerifierFn &&verifier, ParserFn &&parser,
193                            PrinterFn &&printer) {
194   return std::unique_ptr<DynamicAttrDefinition>(
195       new DynamicAttrDefinition(name, dialect, std::move(verifier),
196                                 std::move(parser), std::move(printer)));
197 }
198 
199 DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef,
200                                              ExtensibleDialect *dialect,
201                                              VerifierFn &&verifier,
202                                              ParserFn &&parser,
203                                              PrinterFn &&printer)
204     : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
205       parser(std::move(parser)), printer(std::move(printer)),
206       ctx(dialect->getContext()) {
207   assert(!nameRef.contains('.') &&
208          "name should not be prefixed by the dialect name");
209 }
210 
211 DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect,
212                                              StringRef nameRef)
213     : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {
214   assert(!nameRef.contains('.') &&
215          "name should not be prefixed by the dialect name");
216 }
217 
218 void DynamicAttrDefinition::registerInAttrUniquer() {
219   detail::AttributeUniquer::registerAttribute<DynamicAttr>(&getContext(),
220                                                            getTypeID());
221 }
222 
223 namespace mlir {
224 namespace detail {
225 /// Storage of DynamicAttr.
226 /// Contains a pointer to the attribute definition and attribute parameters.
227 struct DynamicAttrStorage : public AttributeStorage {
228   using KeyTy = std::pair<DynamicAttrDefinition *, ArrayRef<Attribute>>;
229 
230   explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef,
231                               ArrayRef<Attribute> params)
232       : attrDef(attrDef), params(params) {}
233 
234   bool operator==(const KeyTy &key) const {
235     return attrDef == key.first && params == key.second;
236   }
237 
238   static llvm::hash_code hashKey(const KeyTy &key) {
239     return llvm::hash_value(key);
240   }
241 
242   static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc,
243                                        const KeyTy &key) {
244     return new (alloc.allocate<DynamicAttrStorage>())
245         DynamicAttrStorage(key.first, alloc.copyInto(key.second));
246   }
247 
248   /// Definition of the type.
249   DynamicAttrDefinition *attrDef;
250 
251   /// The type parameters.
252   ArrayRef<Attribute> params;
253 };
254 } // namespace detail
255 } // namespace mlir
256 
257 DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef,
258                              ArrayRef<Attribute> params) {
259   auto &ctx = attrDef->getContext();
260   return detail::AttributeUniquer::getWithTypeID<DynamicAttr>(
261       &ctx, attrDef->getTypeID(), attrDef, params);
262 }
263 
264 DynamicAttr
265 DynamicAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
266                         DynamicAttrDefinition *attrDef,
267                         ArrayRef<Attribute> params) {
268   if (failed(attrDef->verify(emitError, params)))
269     return {};
270   return get(attrDef, params);
271 }
272 
273 DynamicAttrDefinition *DynamicAttr::getAttrDef() { return getImpl()->attrDef; }
274 
275 ArrayRef<Attribute> DynamicAttr::getParams() { return getImpl()->params; }
276 
277 bool DynamicAttr::classof(Attribute attr) {
278   return attr.hasTrait<IsDynamicAttrTrait>();
279 }
280 
281 ParseResult DynamicAttr::parse(AsmParser &parser,
282                                DynamicAttrDefinition *attrDef,
283                                DynamicAttr &parsedAttr) {
284   SmallVector<Attribute> params;
285   if (failed(attrDef->parser(parser, params)))
286     return failure();
287   parsedAttr = parser.getChecked<DynamicAttr>(attrDef, params);
288   if (!parsedAttr)
289     return failure();
290   return success();
291 }
292 
293 void DynamicAttr::print(AsmPrinter &printer) {
294   printer << getAttrDef()->getName();
295   getAttrDef()->printer(printer, getParams());
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // Dynamic operation
300 //===----------------------------------------------------------------------===//
301 
302 DynamicOpDefinition::DynamicOpDefinition(
303     StringRef name, ExtensibleDialect *dialect,
304     OperationName::VerifyInvariantsFn &&verifyFn,
305     OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
306     OperationName::ParseAssemblyFn &&parseFn,
307     OperationName::PrintAssemblyFn &&printFn,
308     OperationName::FoldHookFn &&foldHookFn,
309     OperationName::GetCanonicalizationPatternsFn
310         &&getCanonicalizationPatternsFn)
311     : typeID(dialect->allocateTypeID()),
312       name((dialect->getNamespace() + "." + name).str()), dialect(dialect),
313       verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
314       parseFn(std::move(parseFn)), printFn(std::move(printFn)),
315       foldHookFn(std::move(foldHookFn)),
316       getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)) {
317   assert(!name.contains('.') &&
318          "name should not be prefixed by the dialect name");
319 }
320 
321 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
322     StringRef name, ExtensibleDialect *dialect,
323     OperationName::VerifyInvariantsFn &&verifyFn,
324     OperationName::VerifyRegionInvariantsFn &&verifyRegionFn) {
325   auto parseFn = [](OpAsmParser &parser, OperationState &result) {
326     return parser.emitError(
327         parser.getCurrentLocation(),
328         "dynamic operation do not define any parser function");
329   };
330 
331   auto printFn = [](Operation *op, OpAsmPrinter &printer, StringRef) {
332     printer.printGenericOp(op);
333   };
334 
335   return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
336                                   std::move(verifyRegionFn), std::move(parseFn),
337                                   std::move(printFn));
338 }
339 
340 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
341     StringRef name, ExtensibleDialect *dialect,
342     OperationName::VerifyInvariantsFn &&verifyFn,
343     OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
344     OperationName::ParseAssemblyFn &&parseFn,
345     OperationName::PrintAssemblyFn &&printFn) {
346   auto foldHookFn = [](Operation *op, ArrayRef<Attribute> operands,
347                        SmallVectorImpl<OpFoldResult> &results) {
348     return failure();
349   };
350 
351   auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) {
352   };
353 
354   return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
355                                   std::move(verifyRegionFn), std::move(parseFn),
356                                   std::move(printFn), std::move(foldHookFn),
357                                   std::move(getCanonicalizationPatternsFn));
358 }
359 
360 std::unique_ptr<DynamicOpDefinition>
361 DynamicOpDefinition::get(StringRef name, ExtensibleDialect *dialect,
362                          OperationName::VerifyInvariantsFn &&verifyFn,
363                          OperationName::VerifyInvariantsFn &&verifyRegionFn,
364                          OperationName::ParseAssemblyFn &&parseFn,
365                          OperationName::PrintAssemblyFn &&printFn,
366                          OperationName::FoldHookFn &&foldHookFn,
367                          OperationName::GetCanonicalizationPatternsFn
368                              &&getCanonicalizationPatternsFn) {
369   return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition(
370       name, dialect, std::move(verifyFn), std::move(verifyRegionFn),
371       std::move(parseFn), std::move(printFn), std::move(foldHookFn),
372       std::move(getCanonicalizationPatternsFn)));
373 }
374 
375 //===----------------------------------------------------------------------===//
376 // Extensible dialect
377 //===----------------------------------------------------------------------===//
378 
379 namespace {
380 /// Interface that can only be implemented by extensible dialects.
381 /// The interface is used to check if a dialect is extensible or not.
382 class IsExtensibleDialect : public DialectInterface::Base<IsExtensibleDialect> {
383 public:
384   IsExtensibleDialect(Dialect *dialect) : Base(dialect) {}
385 };
386 } // namespace
387 
388 ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx,
389                                      TypeID typeID)
390     : Dialect(name, ctx, typeID) {
391   addInterfaces<IsExtensibleDialect>();
392 }
393 
394 void ExtensibleDialect::registerDynamicType(
395     std::unique_ptr<DynamicTypeDefinition> &&type) {
396   DynamicTypeDefinition *typePtr = type.get();
397   TypeID typeID = type->getTypeID();
398   StringRef name = type->getName();
399   ExtensibleDialect *dialect = type->getDialect();
400 
401   assert(dialect == this &&
402          "trying to register a dynamic type in the wrong dialect");
403 
404   // If a type with the same name is already defined, fail.
405   auto registered = dynTypes.try_emplace(typeID, std::move(type)).second;
406   (void)registered;
407   assert(registered && "type TypeID was not unique");
408 
409   registered = nameToDynTypes.insert({name, typePtr}).second;
410   (void)registered;
411   assert(registered &&
412          "Trying to create a new dynamic type with an existing name");
413 
414   auto abstractType =
415       AbstractType::get(*dialect, DynamicAttr::getInterfaceMap(),
416                         DynamicType::getHasTraitFn(), typeID);
417 
418   /// Add the type to the dialect and the type uniquer.
419   addType(typeID, std::move(abstractType));
420   typePtr->registerInTypeUniquer();
421 }
422 
423 void ExtensibleDialect::registerDynamicAttr(
424     std::unique_ptr<DynamicAttrDefinition> &&attr) {
425   auto *attrPtr = attr.get();
426   auto typeID = attr->getTypeID();
427   auto name = attr->getName();
428   auto *dialect = attr->getDialect();
429 
430   assert(dialect == this &&
431          "trying to register a dynamic attribute in the wrong dialect");
432 
433   // If an attribute with the same name is already defined, fail.
434   auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second;
435   (void)registered;
436   assert(registered && "attribute TypeID was not unique");
437 
438   registered = nameToDynAttrs.insert({name, attrPtr}).second;
439   (void)registered;
440   assert(registered &&
441          "Trying to create a new dynamic attribute with an existing name");
442 
443   auto abstractAttr =
444       AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(),
445                              DynamicAttr::getHasTraitFn(), typeID);
446 
447   /// Add the type to the dialect and the type uniquer.
448   addAttribute(typeID, std::move(abstractAttr));
449   attrPtr->registerInAttrUniquer();
450 }
451 
452 void ExtensibleDialect::registerDynamicOp(
453     std::unique_ptr<DynamicOpDefinition> &&op) {
454   assert(op->dialect == this &&
455          "trying to register a dynamic op in the wrong dialect");
456   auto hasTraitFn = [](TypeID traitId) { return false; };
457 
458   RegisteredOperationName::insert(
459       op->name, *op->dialect, op->typeID, std::move(op->parseFn),
460       std::move(op->printFn), std::move(op->verifyFn),
461       std::move(op->verifyRegionFn), std::move(op->foldHookFn),
462       std::move(op->getCanonicalizationPatternsFn),
463       detail::InterfaceMap::get<>(), std::move(hasTraitFn), {});
464 }
465 
466 bool ExtensibleDialect::classof(const Dialect *dialect) {
467   return const_cast<Dialect *>(dialect)
468       ->getRegisteredInterface<IsExtensibleDialect>();
469 }
470 
471 OptionalParseResult ExtensibleDialect::parseOptionalDynamicType(
472     StringRef typeName, AsmParser &parser, Type &resultType) const {
473   DynamicTypeDefinition *typeDef = lookupTypeDefinition(typeName);
474   if (!typeDef)
475     return llvm::None;
476 
477   DynamicType dynType;
478   if (DynamicType::parse(parser, typeDef, dynType))
479     return failure();
480   resultType = dynType;
481   return success();
482 }
483 
484 LogicalResult ExtensibleDialect::printIfDynamicType(Type type,
485                                                     AsmPrinter &printer) {
486   if (auto dynType = type.dyn_cast<DynamicType>()) {
487     dynType.print(printer);
488     return success();
489   }
490   return failure();
491 }
492 
493 OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr(
494     StringRef attrName, AsmParser &parser, Attribute &resultAttr) const {
495   DynamicAttrDefinition *attrDef = lookupAttrDefinition(attrName);
496   if (!attrDef)
497     return llvm::None;
498 
499   DynamicAttr dynAttr;
500   if (DynamicAttr::parse(parser, attrDef, dynAttr))
501     return failure();
502   resultAttr = dynAttr;
503   return success();
504 }
505 
506 LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute,
507                                                     AsmPrinter &printer) {
508   if (auto dynAttr = attribute.dyn_cast<DynamicAttr>()) {
509     dynAttr.print(printer);
510     return success();
511   }
512   return failure();
513 }
514