1 //===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===// 2 // 3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/ExtensibleDialect.h" 10 #include "mlir/IR/AttributeSupport.h" 11 #include "mlir/IR/DialectImplementation.h" 12 #include "mlir/IR/OperationSupport.h" 13 #include "mlir/IR/StorageUniquerSupport.h" 14 #include "mlir/Support/LogicalResult.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // Dynamic types and attributes shared functions 20 //===----------------------------------------------------------------------===// 21 22 /// Default parser for dynamic attribute or type parameters. 23 /// Parse in the format '(<>)?' or '<attr (,attr)*>'. 24 static LogicalResult 25 typeOrAttrParser(AsmParser &parser, SmallVectorImpl<Attribute> &parsedParams) { 26 // No parameters 27 if (parser.parseOptionalLess() || !parser.parseOptionalGreater()) 28 return success(); 29 30 Attribute attr; 31 if (parser.parseAttribute(attr)) 32 return failure(); 33 parsedParams.push_back(attr); 34 35 while (parser.parseOptionalGreater()) { 36 Attribute attr; 37 if (parser.parseComma() || parser.parseAttribute(attr)) 38 return failure(); 39 parsedParams.push_back(attr); 40 } 41 42 return success(); 43 } 44 45 /// Default printer for dynamic attribute or type parameters. 46 /// Print in the format '(<>)?' or '<attr (,attr)*>'. 47 static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef<Attribute> params) { 48 if (params.empty()) 49 return; 50 51 printer << "<"; 52 interleaveComma(params, printer.getStream()); 53 printer << ">"; 54 } 55 56 //===----------------------------------------------------------------------===// 57 // Dynamic type 58 //===----------------------------------------------------------------------===// 59 60 std::unique_ptr<DynamicTypeDefinition> 61 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect, 62 VerifierFn &&verifier) { 63 return DynamicTypeDefinition::get(name, dialect, std::move(verifier), 64 typeOrAttrParser, typeOrAttrPrinter); 65 } 66 67 std::unique_ptr<DynamicTypeDefinition> 68 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect, 69 VerifierFn &&verifier, ParserFn &&parser, 70 PrinterFn &&printer) { 71 return std::unique_ptr<DynamicTypeDefinition>( 72 new DynamicTypeDefinition(name, dialect, std::move(verifier), 73 std::move(parser), std::move(printer))); 74 } 75 76 DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef, 77 ExtensibleDialect *dialect, 78 VerifierFn &&verifier, 79 ParserFn &&parser, 80 PrinterFn &&printer) 81 : name(nameRef), dialect(dialect), verifier(std::move(verifier)), 82 parser(std::move(parser)), printer(std::move(printer)), 83 ctx(dialect->getContext()) {} 84 85 DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect *dialect, 86 StringRef nameRef) 87 : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {} 88 89 void DynamicTypeDefinition::registerInTypeUniquer() { 90 detail::TypeUniquer::registerType<DynamicType>(&getContext(), getTypeID()); 91 } 92 93 namespace mlir { 94 namespace detail { 95 /// Storage of DynamicType. 96 /// Contains a pointer to the type definition and type parameters. 97 struct DynamicTypeStorage : public TypeStorage { 98 99 using KeyTy = std::pair<DynamicTypeDefinition *, ArrayRef<Attribute>>; 100 101 explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef, 102 ArrayRef<Attribute> params) 103 : typeDef(typeDef), params(params) {} 104 105 bool operator==(const KeyTy &key) const { 106 return typeDef == key.first && params == key.second; 107 } 108 109 static llvm::hash_code hashKey(const KeyTy &key) { 110 return llvm::hash_value(key); 111 } 112 113 static DynamicTypeStorage *construct(TypeStorageAllocator &alloc, 114 const KeyTy &key) { 115 return new (alloc.allocate<DynamicTypeStorage>()) 116 DynamicTypeStorage(key.first, alloc.copyInto(key.second)); 117 } 118 119 /// Definition of the type. 120 DynamicTypeDefinition *typeDef; 121 122 /// The type parameters. 123 ArrayRef<Attribute> params; 124 }; 125 } // namespace detail 126 } // namespace mlir 127 128 DynamicType DynamicType::get(DynamicTypeDefinition *typeDef, 129 ArrayRef<Attribute> params) { 130 auto &ctx = typeDef->getContext(); 131 auto emitError = detail::getDefaultDiagnosticEmitFn(&ctx); 132 assert(succeeded(typeDef->verify(emitError, params))); 133 return detail::TypeUniquer::getWithTypeID<DynamicType>( 134 &ctx, typeDef->getTypeID(), typeDef, params); 135 } 136 137 DynamicType 138 DynamicType::getChecked(function_ref<InFlightDiagnostic()> emitError, 139 DynamicTypeDefinition *typeDef, 140 ArrayRef<Attribute> params) { 141 if (failed(typeDef->verify(emitError, params))) 142 return {}; 143 auto &ctx = typeDef->getContext(); 144 return detail::TypeUniquer::getWithTypeID<DynamicType>( 145 &ctx, typeDef->getTypeID(), typeDef, params); 146 } 147 148 DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; } 149 150 ArrayRef<Attribute> DynamicType::getParams() { return getImpl()->params; } 151 152 bool DynamicType::classof(Type type) { 153 return type.hasTrait<TypeTrait::IsDynamicType>(); 154 } 155 156 ParseResult DynamicType::parse(AsmParser &parser, 157 DynamicTypeDefinition *typeDef, 158 DynamicType &parsedType) { 159 SmallVector<Attribute> params; 160 if (failed(typeDef->parser(parser, params))) 161 return failure(); 162 parsedType = parser.getChecked<DynamicType>(typeDef, params); 163 if (!parsedType) 164 return failure(); 165 return success(); 166 } 167 168 void DynamicType::print(AsmPrinter &printer) { 169 printer << getTypeDef()->getName(); 170 getTypeDef()->printer(printer, getParams()); 171 } 172 173 //===----------------------------------------------------------------------===// 174 // Dynamic attribute 175 //===----------------------------------------------------------------------===// 176 177 std::unique_ptr<DynamicAttrDefinition> 178 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect, 179 VerifierFn &&verifier) { 180 return DynamicAttrDefinition::get(name, dialect, std::move(verifier), 181 typeOrAttrParser, typeOrAttrPrinter); 182 } 183 184 std::unique_ptr<DynamicAttrDefinition> 185 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect, 186 VerifierFn &&verifier, ParserFn &&parser, 187 PrinterFn &&printer) { 188 return std::unique_ptr<DynamicAttrDefinition>( 189 new DynamicAttrDefinition(name, dialect, std::move(verifier), 190 std::move(parser), std::move(printer))); 191 } 192 193 DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef, 194 ExtensibleDialect *dialect, 195 VerifierFn &&verifier, 196 ParserFn &&parser, 197 PrinterFn &&printer) 198 : name(nameRef), dialect(dialect), verifier(std::move(verifier)), 199 parser(std::move(parser)), printer(std::move(printer)), 200 ctx(dialect->getContext()) {} 201 202 DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect, 203 StringRef nameRef) 204 : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {} 205 206 void DynamicAttrDefinition::registerInAttrUniquer() { 207 detail::AttributeUniquer::registerAttribute<DynamicAttr>(&getContext(), 208 getTypeID()); 209 } 210 211 namespace mlir { 212 namespace detail { 213 /// Storage of DynamicAttr. 214 /// Contains a pointer to the attribute definition and attribute parameters. 215 struct DynamicAttrStorage : public AttributeStorage { 216 using KeyTy = std::pair<DynamicAttrDefinition *, ArrayRef<Attribute>>; 217 218 explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef, 219 ArrayRef<Attribute> params) 220 : attrDef(attrDef), params(params) {} 221 222 bool operator==(const KeyTy &key) const { 223 return attrDef == key.first && params == key.second; 224 } 225 226 static llvm::hash_code hashKey(const KeyTy &key) { 227 return llvm::hash_value(key); 228 } 229 230 static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc, 231 const KeyTy &key) { 232 return new (alloc.allocate<DynamicAttrStorage>()) 233 DynamicAttrStorage(key.first, alloc.copyInto(key.second)); 234 } 235 236 /// Definition of the type. 237 DynamicAttrDefinition *attrDef; 238 239 /// The type parameters. 240 ArrayRef<Attribute> params; 241 }; 242 } // namespace detail 243 } // namespace mlir 244 245 DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef, 246 ArrayRef<Attribute> params) { 247 auto &ctx = attrDef->getContext(); 248 return detail::AttributeUniquer::getWithTypeID<DynamicAttr>( 249 &ctx, attrDef->getTypeID(), attrDef, params); 250 } 251 252 DynamicAttr 253 DynamicAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, 254 DynamicAttrDefinition *attrDef, 255 ArrayRef<Attribute> params) { 256 if (failed(attrDef->verify(emitError, params))) 257 return {}; 258 return get(attrDef, params); 259 } 260 261 DynamicAttrDefinition *DynamicAttr::getAttrDef() { return getImpl()->attrDef; } 262 263 ArrayRef<Attribute> DynamicAttr::getParams() { return getImpl()->params; } 264 265 bool DynamicAttr::classof(Attribute attr) { 266 return attr.hasTrait<AttributeTrait::IsDynamicAttr>(); 267 } 268 269 ParseResult DynamicAttr::parse(AsmParser &parser, 270 DynamicAttrDefinition *attrDef, 271 DynamicAttr &parsedAttr) { 272 SmallVector<Attribute> params; 273 if (failed(attrDef->parser(parser, params))) 274 return failure(); 275 parsedAttr = parser.getChecked<DynamicAttr>(attrDef, params); 276 if (!parsedAttr) 277 return failure(); 278 return success(); 279 } 280 281 void DynamicAttr::print(AsmPrinter &printer) { 282 printer << getAttrDef()->getName(); 283 getAttrDef()->printer(printer, getParams()); 284 } 285 286 //===----------------------------------------------------------------------===// 287 // Dynamic operation 288 //===----------------------------------------------------------------------===// 289 290 DynamicOpDefinition::DynamicOpDefinition( 291 StringRef name, ExtensibleDialect *dialect, 292 OperationName::VerifyInvariantsFn &&verifyFn, 293 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn, 294 OperationName::ParseAssemblyFn &&parseFn, 295 OperationName::PrintAssemblyFn &&printFn, 296 OperationName::FoldHookFn &&foldHookFn, 297 OperationName::GetCanonicalizationPatternsFn 298 &&getCanonicalizationPatternsFn, 299 OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn) 300 : typeID(dialect->allocateTypeID()), 301 name((dialect->getNamespace() + "." + name).str()), dialect(dialect), 302 verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)), 303 parseFn(std::move(parseFn)), printFn(std::move(printFn)), 304 foldHookFn(std::move(foldHookFn)), 305 getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)), 306 populateDefaultAttrsFn(std::move(populateDefaultAttrsFn)) {} 307 308 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get( 309 StringRef name, ExtensibleDialect *dialect, 310 OperationName::VerifyInvariantsFn &&verifyFn, 311 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn) { 312 auto parseFn = [](OpAsmParser &parser, OperationState &result) { 313 return parser.emitError( 314 parser.getCurrentLocation(), 315 "dynamic operation do not define any parser function"); 316 }; 317 318 auto printFn = [](Operation *op, OpAsmPrinter &printer, StringRef) { 319 printer.printGenericOp(op); 320 }; 321 322 return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), 323 std::move(verifyRegionFn), std::move(parseFn), 324 std::move(printFn)); 325 } 326 327 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get( 328 StringRef name, ExtensibleDialect *dialect, 329 OperationName::VerifyInvariantsFn &&verifyFn, 330 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn, 331 OperationName::ParseAssemblyFn &&parseFn, 332 OperationName::PrintAssemblyFn &&printFn) { 333 auto foldHookFn = [](Operation *op, ArrayRef<Attribute> operands, 334 SmallVectorImpl<OpFoldResult> &results) { 335 return failure(); 336 }; 337 338 auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) { 339 }; 340 341 auto populateDefaultAttrsFn = [](const RegisteredOperationName &, 342 NamedAttrList &) {}; 343 344 return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), 345 std::move(verifyRegionFn), std::move(parseFn), 346 std::move(printFn), std::move(foldHookFn), 347 std::move(getCanonicalizationPatternsFn), 348 std::move(populateDefaultAttrsFn)); 349 } 350 351 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get( 352 StringRef name, ExtensibleDialect *dialect, 353 OperationName::VerifyInvariantsFn &&verifyFn, 354 OperationName::VerifyInvariantsFn &&verifyRegionFn, 355 OperationName::ParseAssemblyFn &&parseFn, 356 OperationName::PrintAssemblyFn &&printFn, 357 OperationName::FoldHookFn &&foldHookFn, 358 OperationName::GetCanonicalizationPatternsFn 359 &&getCanonicalizationPatternsFn, 360 OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn) { 361 return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition( 362 name, dialect, std::move(verifyFn), std::move(verifyRegionFn), 363 std::move(parseFn), std::move(printFn), std::move(foldHookFn), 364 std::move(getCanonicalizationPatternsFn), 365 std::move(populateDefaultAttrsFn))); 366 } 367 368 //===----------------------------------------------------------------------===// 369 // Extensible dialect 370 //===----------------------------------------------------------------------===// 371 372 namespace { 373 /// Interface that can only be implemented by extensible dialects. 374 /// The interface is used to check if a dialect is extensible or not. 375 class IsExtensibleDialect : public DialectInterface::Base<IsExtensibleDialect> { 376 public: 377 IsExtensibleDialect(Dialect *dialect) : Base(dialect) {} 378 379 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsExtensibleDialect) 380 }; 381 } // namespace 382 383 ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx, 384 TypeID typeID) 385 : Dialect(name, ctx, typeID) { 386 addInterfaces<IsExtensibleDialect>(); 387 } 388 389 void ExtensibleDialect::registerDynamicType( 390 std::unique_ptr<DynamicTypeDefinition> &&type) { 391 DynamicTypeDefinition *typePtr = type.get(); 392 TypeID typeID = type->getTypeID(); 393 StringRef name = type->getName(); 394 ExtensibleDialect *dialect = type->getDialect(); 395 396 assert(dialect == this && 397 "trying to register a dynamic type in the wrong dialect"); 398 399 // If a type with the same name is already defined, fail. 400 auto registered = dynTypes.try_emplace(typeID, std::move(type)).second; 401 (void)registered; 402 assert(registered && "type TypeID was not unique"); 403 404 registered = nameToDynTypes.insert({name, typePtr}).second; 405 (void)registered; 406 assert(registered && 407 "Trying to create a new dynamic type with an existing name"); 408 409 auto abstractType = 410 AbstractType::get(*dialect, DynamicAttr::getInterfaceMap(), 411 DynamicType::getHasTraitFn(), typeID); 412 413 /// Add the type to the dialect and the type uniquer. 414 addType(typeID, std::move(abstractType)); 415 typePtr->registerInTypeUniquer(); 416 } 417 418 void ExtensibleDialect::registerDynamicAttr( 419 std::unique_ptr<DynamicAttrDefinition> &&attr) { 420 auto *attrPtr = attr.get(); 421 auto typeID = attr->getTypeID(); 422 auto name = attr->getName(); 423 auto *dialect = attr->getDialect(); 424 425 assert(dialect == this && 426 "trying to register a dynamic attribute in the wrong dialect"); 427 428 // If an attribute with the same name is already defined, fail. 429 auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second; 430 (void)registered; 431 assert(registered && "attribute TypeID was not unique"); 432 433 registered = nameToDynAttrs.insert({name, attrPtr}).second; 434 (void)registered; 435 assert(registered && 436 "Trying to create a new dynamic attribute with an existing name"); 437 438 auto abstractAttr = 439 AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(), 440 DynamicAttr::getHasTraitFn(), typeID); 441 442 /// Add the type to the dialect and the type uniquer. 443 addAttribute(typeID, std::move(abstractAttr)); 444 attrPtr->registerInAttrUniquer(); 445 } 446 447 void ExtensibleDialect::registerDynamicOp( 448 std::unique_ptr<DynamicOpDefinition> &&op) { 449 assert(op->dialect == this && 450 "trying to register a dynamic op in the wrong dialect"); 451 auto hasTraitFn = [](TypeID traitId) { return false; }; 452 453 RegisteredOperationName::insert( 454 op->name, *op->dialect, op->typeID, std::move(op->parseFn), 455 std::move(op->printFn), std::move(op->verifyFn), 456 std::move(op->verifyRegionFn), std::move(op->foldHookFn), 457 std::move(op->getCanonicalizationPatternsFn), 458 detail::InterfaceMap::get<>(), std::move(hasTraitFn), {}, 459 std::move(op->populateDefaultAttrsFn)); 460 } 461 462 bool ExtensibleDialect::classof(const Dialect *dialect) { 463 return const_cast<Dialect *>(dialect) 464 ->getRegisteredInterface<IsExtensibleDialect>(); 465 } 466 467 OptionalParseResult ExtensibleDialect::parseOptionalDynamicType( 468 StringRef typeName, AsmParser &parser, Type &resultType) const { 469 DynamicTypeDefinition *typeDef = lookupTypeDefinition(typeName); 470 if (!typeDef) 471 return llvm::None; 472 473 DynamicType dynType; 474 if (DynamicType::parse(parser, typeDef, dynType)) 475 return failure(); 476 resultType = dynType; 477 return success(); 478 } 479 480 LogicalResult ExtensibleDialect::printIfDynamicType(Type type, 481 AsmPrinter &printer) { 482 if (auto dynType = type.dyn_cast<DynamicType>()) { 483 dynType.print(printer); 484 return success(); 485 } 486 return failure(); 487 } 488 489 OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr( 490 StringRef attrName, AsmParser &parser, Attribute &resultAttr) const { 491 DynamicAttrDefinition *attrDef = lookupAttrDefinition(attrName); 492 if (!attrDef) 493 return llvm::None; 494 495 DynamicAttr dynAttr; 496 if (DynamicAttr::parse(parser, attrDef, dynAttr)) 497 return failure(); 498 resultAttr = dynAttr; 499 return success(); 500 } 501 502 LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute, 503 AsmPrinter &printer) { 504 if (auto dynAttr = attribute.dyn_cast<DynamicAttr>()) { 505 dynAttr.print(printer); 506 return success(); 507 } 508 return failure(); 509 } 510