1 //===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/ExtensibleDialect.h"
10 #include "mlir/IR/AttributeSupport.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "mlir/IR/OperationSupport.h"
13 #include "mlir/IR/StorageUniquerSupport.h"
14 #include "mlir/Support/LogicalResult.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // Dynamic types and attributes shared functions
20 //===----------------------------------------------------------------------===//
21 
22 /// Default parser for dynamic attribute or type parameters.
23 /// Parse in the format '(<>)?' or '<attr (,attr)*>'.
24 static LogicalResult
25 typeOrAttrParser(AsmParser &parser, SmallVectorImpl<Attribute> &parsedParams) {
26   // No parameters
27   if (parser.parseOptionalLess() || !parser.parseOptionalGreater())
28     return success();
29 
30   Attribute attr;
31   if (parser.parseAttribute(attr))
32     return failure();
33   parsedParams.push_back(attr);
34 
35   while (parser.parseOptionalGreater()) {
36     Attribute attr;
37     if (parser.parseComma() || parser.parseAttribute(attr))
38       return failure();
39     parsedParams.push_back(attr);
40   }
41 
42   return success();
43 }
44 
45 /// Default printer for dynamic attribute or type parameters.
46 /// Print in the format '(<>)?' or '<attr (,attr)*>'.
47 static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef<Attribute> params) {
48   if (params.empty())
49     return;
50 
51   printer << "<";
52   interleaveComma(params, printer.getStream());
53   printer << ">";
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // Dynamic type
58 //===----------------------------------------------------------------------===//
59 
60 std::unique_ptr<DynamicTypeDefinition>
61 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
62                            VerifierFn &&verifier) {
63   return DynamicTypeDefinition::get(name, dialect, std::move(verifier),
64                                     typeOrAttrParser, typeOrAttrPrinter);
65 }
66 
67 std::unique_ptr<DynamicTypeDefinition>
68 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
69                            VerifierFn &&verifier, ParserFn &&parser,
70                            PrinterFn &&printer) {
71   return std::unique_ptr<DynamicTypeDefinition>(
72       new DynamicTypeDefinition(name, dialect, std::move(verifier),
73                                 std::move(parser), std::move(printer)));
74 }
75 
76 DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef,
77                                              ExtensibleDialect *dialect,
78                                              VerifierFn &&verifier,
79                                              ParserFn &&parser,
80                                              PrinterFn &&printer)
81     : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
82       parser(std::move(parser)), printer(std::move(printer)),
83       ctx(dialect->getContext()) {}
84 
85 DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect *dialect,
86                                              StringRef nameRef)
87     : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {}
88 
89 void DynamicTypeDefinition::registerInTypeUniquer() {
90   detail::TypeUniquer::registerType<DynamicType>(&getContext(), getTypeID());
91 }
92 
93 namespace mlir {
94 namespace detail {
95 /// Storage of DynamicType.
96 /// Contains a pointer to the type definition and type parameters.
97 struct DynamicTypeStorage : public TypeStorage {
98 
99   using KeyTy = std::pair<DynamicTypeDefinition *, ArrayRef<Attribute>>;
100 
101   explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef,
102                               ArrayRef<Attribute> params)
103       : typeDef(typeDef), params(params) {}
104 
105   bool operator==(const KeyTy &key) const {
106     return typeDef == key.first && params == key.second;
107   }
108 
109   static llvm::hash_code hashKey(const KeyTy &key) {
110     return llvm::hash_value(key);
111   }
112 
113   static DynamicTypeStorage *construct(TypeStorageAllocator &alloc,
114                                        const KeyTy &key) {
115     return new (alloc.allocate<DynamicTypeStorage>())
116         DynamicTypeStorage(key.first, alloc.copyInto(key.second));
117   }
118 
119   /// Definition of the type.
120   DynamicTypeDefinition *typeDef;
121 
122   /// The type parameters.
123   ArrayRef<Attribute> params;
124 };
125 } // namespace detail
126 } // namespace mlir
127 
128 DynamicType DynamicType::get(DynamicTypeDefinition *typeDef,
129                              ArrayRef<Attribute> params) {
130   auto &ctx = typeDef->getContext();
131   auto emitError = detail::getDefaultDiagnosticEmitFn(&ctx);
132   assert(succeeded(typeDef->verify(emitError, params)));
133   return detail::TypeUniquer::getWithTypeID<DynamicType>(
134       &ctx, typeDef->getTypeID(), typeDef, params);
135 }
136 
137 DynamicType
138 DynamicType::getChecked(function_ref<InFlightDiagnostic()> emitError,
139                         DynamicTypeDefinition *typeDef,
140                         ArrayRef<Attribute> params) {
141   if (failed(typeDef->verify(emitError, params)))
142     return {};
143   auto &ctx = typeDef->getContext();
144   return detail::TypeUniquer::getWithTypeID<DynamicType>(
145       &ctx, typeDef->getTypeID(), typeDef, params);
146 }
147 
148 DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; }
149 
150 ArrayRef<Attribute> DynamicType::getParams() { return getImpl()->params; }
151 
152 bool DynamicType::classof(Type type) {
153   return type.hasTrait<TypeTrait::IsDynamicType>();
154 }
155 
156 ParseResult DynamicType::parse(AsmParser &parser,
157                                DynamicTypeDefinition *typeDef,
158                                DynamicType &parsedType) {
159   SmallVector<Attribute> params;
160   if (failed(typeDef->parser(parser, params)))
161     return failure();
162   parsedType = parser.getChecked<DynamicType>(typeDef, params);
163   if (!parsedType)
164     return failure();
165   return success();
166 }
167 
168 void DynamicType::print(AsmPrinter &printer) {
169   printer << getTypeDef()->getName();
170   getTypeDef()->printer(printer, getParams());
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // Dynamic attribute
175 //===----------------------------------------------------------------------===//
176 
177 std::unique_ptr<DynamicAttrDefinition>
178 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
179                            VerifierFn &&verifier) {
180   return DynamicAttrDefinition::get(name, dialect, std::move(verifier),
181                                     typeOrAttrParser, typeOrAttrPrinter);
182 }
183 
184 std::unique_ptr<DynamicAttrDefinition>
185 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
186                            VerifierFn &&verifier, ParserFn &&parser,
187                            PrinterFn &&printer) {
188   return std::unique_ptr<DynamicAttrDefinition>(
189       new DynamicAttrDefinition(name, dialect, std::move(verifier),
190                                 std::move(parser), std::move(printer)));
191 }
192 
193 DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef,
194                                              ExtensibleDialect *dialect,
195                                              VerifierFn &&verifier,
196                                              ParserFn &&parser,
197                                              PrinterFn &&printer)
198     : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
199       parser(std::move(parser)), printer(std::move(printer)),
200       ctx(dialect->getContext()) {}
201 
202 DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect,
203                                              StringRef nameRef)
204     : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {}
205 
206 void DynamicAttrDefinition::registerInAttrUniquer() {
207   detail::AttributeUniquer::registerAttribute<DynamicAttr>(&getContext(),
208                                                            getTypeID());
209 }
210 
211 namespace mlir {
212 namespace detail {
213 /// Storage of DynamicAttr.
214 /// Contains a pointer to the attribute definition and attribute parameters.
215 struct DynamicAttrStorage : public AttributeStorage {
216   using KeyTy = std::pair<DynamicAttrDefinition *, ArrayRef<Attribute>>;
217 
218   explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef,
219                               ArrayRef<Attribute> params)
220       : attrDef(attrDef), params(params) {}
221 
222   bool operator==(const KeyTy &key) const {
223     return attrDef == key.first && params == key.second;
224   }
225 
226   static llvm::hash_code hashKey(const KeyTy &key) {
227     return llvm::hash_value(key);
228   }
229 
230   static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc,
231                                        const KeyTy &key) {
232     return new (alloc.allocate<DynamicAttrStorage>())
233         DynamicAttrStorage(key.first, alloc.copyInto(key.second));
234   }
235 
236   /// Definition of the type.
237   DynamicAttrDefinition *attrDef;
238 
239   /// The type parameters.
240   ArrayRef<Attribute> params;
241 };
242 } // namespace detail
243 } // namespace mlir
244 
245 DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef,
246                              ArrayRef<Attribute> params) {
247   auto &ctx = attrDef->getContext();
248   return detail::AttributeUniquer::getWithTypeID<DynamicAttr>(
249       &ctx, attrDef->getTypeID(), attrDef, params);
250 }
251 
252 DynamicAttr
253 DynamicAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
254                         DynamicAttrDefinition *attrDef,
255                         ArrayRef<Attribute> params) {
256   if (failed(attrDef->verify(emitError, params)))
257     return {};
258   return get(attrDef, params);
259 }
260 
261 DynamicAttrDefinition *DynamicAttr::getAttrDef() { return getImpl()->attrDef; }
262 
263 ArrayRef<Attribute> DynamicAttr::getParams() { return getImpl()->params; }
264 
265 bool DynamicAttr::classof(Attribute attr) {
266   return attr.hasTrait<AttributeTrait::IsDynamicAttr>();
267 }
268 
269 ParseResult DynamicAttr::parse(AsmParser &parser,
270                                DynamicAttrDefinition *attrDef,
271                                DynamicAttr &parsedAttr) {
272   SmallVector<Attribute> params;
273   if (failed(attrDef->parser(parser, params)))
274     return failure();
275   parsedAttr = parser.getChecked<DynamicAttr>(attrDef, params);
276   if (!parsedAttr)
277     return failure();
278   return success();
279 }
280 
281 void DynamicAttr::print(AsmPrinter &printer) {
282   printer << getAttrDef()->getName();
283   getAttrDef()->printer(printer, getParams());
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // Dynamic operation
288 //===----------------------------------------------------------------------===//
289 
290 DynamicOpDefinition::DynamicOpDefinition(
291     StringRef name, ExtensibleDialect *dialect,
292     OperationName::VerifyInvariantsFn &&verifyFn,
293     OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
294     OperationName::ParseAssemblyFn &&parseFn,
295     OperationName::PrintAssemblyFn &&printFn,
296     OperationName::FoldHookFn &&foldHookFn,
297     OperationName::GetCanonicalizationPatternsFn
298         &&getCanonicalizationPatternsFn)
299     : typeID(dialect->allocateTypeID()),
300       name((dialect->getNamespace() + "." + name).str()), dialect(dialect),
301       verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
302       parseFn(std::move(parseFn)), printFn(std::move(printFn)),
303       foldHookFn(std::move(foldHookFn)),
304       getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)) {}
305 
306 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
307     StringRef name, ExtensibleDialect *dialect,
308     OperationName::VerifyInvariantsFn &&verifyFn,
309     OperationName::VerifyRegionInvariantsFn &&verifyRegionFn) {
310   auto parseFn = [](OpAsmParser &parser, OperationState &result) {
311     return parser.emitError(
312         parser.getCurrentLocation(),
313         "dynamic operation do not define any parser function");
314   };
315 
316   auto printFn = [](Operation *op, OpAsmPrinter &printer, StringRef) {
317     printer.printGenericOp(op);
318   };
319 
320   return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
321                                   std::move(verifyRegionFn), std::move(parseFn),
322                                   std::move(printFn));
323 }
324 
325 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
326     StringRef name, ExtensibleDialect *dialect,
327     OperationName::VerifyInvariantsFn &&verifyFn,
328     OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
329     OperationName::ParseAssemblyFn &&parseFn,
330     OperationName::PrintAssemblyFn &&printFn) {
331   auto foldHookFn = [](Operation *op, ArrayRef<Attribute> operands,
332                        SmallVectorImpl<OpFoldResult> &results) {
333     return failure();
334   };
335 
336   auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) {
337   };
338 
339   return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
340                                   std::move(verifyRegionFn), std::move(parseFn),
341                                   std::move(printFn), std::move(foldHookFn),
342                                   std::move(getCanonicalizationPatternsFn));
343 }
344 
345 std::unique_ptr<DynamicOpDefinition>
346 DynamicOpDefinition::get(StringRef name, ExtensibleDialect *dialect,
347                          OperationName::VerifyInvariantsFn &&verifyFn,
348                          OperationName::VerifyInvariantsFn &&verifyRegionFn,
349                          OperationName::ParseAssemblyFn &&parseFn,
350                          OperationName::PrintAssemblyFn &&printFn,
351                          OperationName::FoldHookFn &&foldHookFn,
352                          OperationName::GetCanonicalizationPatternsFn
353                              &&getCanonicalizationPatternsFn) {
354   return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition(
355       name, dialect, std::move(verifyFn), std::move(verifyRegionFn),
356       std::move(parseFn), std::move(printFn), std::move(foldHookFn),
357       std::move(getCanonicalizationPatternsFn)));
358 }
359 
360 //===----------------------------------------------------------------------===//
361 // Extensible dialect
362 //===----------------------------------------------------------------------===//
363 
364 namespace {
365 /// Interface that can only be implemented by extensible dialects.
366 /// The interface is used to check if a dialect is extensible or not.
367 class IsExtensibleDialect : public DialectInterface::Base<IsExtensibleDialect> {
368 public:
369   IsExtensibleDialect(Dialect *dialect) : Base(dialect) {}
370 
371   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsExtensibleDialect)
372 };
373 } // namespace
374 
375 ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx,
376                                      TypeID typeID)
377     : Dialect(name, ctx, typeID) {
378   addInterfaces<IsExtensibleDialect>();
379 }
380 
381 void ExtensibleDialect::registerDynamicType(
382     std::unique_ptr<DynamicTypeDefinition> &&type) {
383   DynamicTypeDefinition *typePtr = type.get();
384   TypeID typeID = type->getTypeID();
385   StringRef name = type->getName();
386   ExtensibleDialect *dialect = type->getDialect();
387 
388   assert(dialect == this &&
389          "trying to register a dynamic type in the wrong dialect");
390 
391   // If a type with the same name is already defined, fail.
392   auto registered = dynTypes.try_emplace(typeID, std::move(type)).second;
393   (void)registered;
394   assert(registered && "type TypeID was not unique");
395 
396   registered = nameToDynTypes.insert({name, typePtr}).second;
397   (void)registered;
398   assert(registered &&
399          "Trying to create a new dynamic type with an existing name");
400 
401   auto abstractType =
402       AbstractType::get(*dialect, DynamicAttr::getInterfaceMap(),
403                         DynamicType::getHasTraitFn(), typeID);
404 
405   /// Add the type to the dialect and the type uniquer.
406   addType(typeID, std::move(abstractType));
407   typePtr->registerInTypeUniquer();
408 }
409 
410 void ExtensibleDialect::registerDynamicAttr(
411     std::unique_ptr<DynamicAttrDefinition> &&attr) {
412   auto *attrPtr = attr.get();
413   auto typeID = attr->getTypeID();
414   auto name = attr->getName();
415   auto *dialect = attr->getDialect();
416 
417   assert(dialect == this &&
418          "trying to register a dynamic attribute in the wrong dialect");
419 
420   // If an attribute with the same name is already defined, fail.
421   auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second;
422   (void)registered;
423   assert(registered && "attribute TypeID was not unique");
424 
425   registered = nameToDynAttrs.insert({name, attrPtr}).second;
426   (void)registered;
427   assert(registered &&
428          "Trying to create a new dynamic attribute with an existing name");
429 
430   auto abstractAttr =
431       AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(),
432                              DynamicAttr::getHasTraitFn(), typeID);
433 
434   /// Add the type to the dialect and the type uniquer.
435   addAttribute(typeID, std::move(abstractAttr));
436   attrPtr->registerInAttrUniquer();
437 }
438 
439 void ExtensibleDialect::registerDynamicOp(
440     std::unique_ptr<DynamicOpDefinition> &&op) {
441   assert(op->dialect == this &&
442          "trying to register a dynamic op in the wrong dialect");
443   auto hasTraitFn = [](TypeID traitId) { return false; };
444 
445   RegisteredOperationName::insert(
446       op->name, *op->dialect, op->typeID, std::move(op->parseFn),
447       std::move(op->printFn), std::move(op->verifyFn),
448       std::move(op->verifyRegionFn), std::move(op->foldHookFn),
449       std::move(op->getCanonicalizationPatternsFn),
450       detail::InterfaceMap::get<>(), std::move(hasTraitFn), {});
451 }
452 
453 bool ExtensibleDialect::classof(const Dialect *dialect) {
454   return const_cast<Dialect *>(dialect)
455       ->getRegisteredInterface<IsExtensibleDialect>();
456 }
457 
458 OptionalParseResult ExtensibleDialect::parseOptionalDynamicType(
459     StringRef typeName, AsmParser &parser, Type &resultType) const {
460   DynamicTypeDefinition *typeDef = lookupTypeDefinition(typeName);
461   if (!typeDef)
462     return llvm::None;
463 
464   DynamicType dynType;
465   if (DynamicType::parse(parser, typeDef, dynType))
466     return failure();
467   resultType = dynType;
468   return success();
469 }
470 
471 LogicalResult ExtensibleDialect::printIfDynamicType(Type type,
472                                                     AsmPrinter &printer) {
473   if (auto dynType = type.dyn_cast<DynamicType>()) {
474     dynType.print(printer);
475     return success();
476   }
477   return failure();
478 }
479 
480 OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr(
481     StringRef attrName, AsmParser &parser, Attribute &resultAttr) const {
482   DynamicAttrDefinition *attrDef = lookupAttrDefinition(attrName);
483   if (!attrDef)
484     return llvm::None;
485 
486   DynamicAttr dynAttr;
487   if (DynamicAttr::parse(parser, attrDef, dynAttr))
488     return failure();
489   resultAttr = dynAttr;
490   return success();
491 }
492 
493 LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute,
494                                                     AsmPrinter &printer) {
495   if (auto dynAttr = attribute.dyn_cast<DynamicAttr>()) {
496     dynAttr.print(printer);
497     return success();
498   }
499   return failure();
500 }
501