1 //===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===// 2 // 3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/ExtensibleDialect.h" 10 #include "mlir/IR/AttributeSupport.h" 11 #include "mlir/IR/DialectImplementation.h" 12 #include "mlir/IR/OperationSupport.h" 13 #include "mlir/IR/StorageUniquerSupport.h" 14 #include "mlir/Support/LogicalResult.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // Dynamic types and attributes shared functions 20 //===----------------------------------------------------------------------===// 21 22 /// Default parser for dynamic attribute or type parameters. 23 /// Parse in the format '(<>)?' or '<attr (,attr)*>'. 24 static LogicalResult 25 typeOrAttrParser(AsmParser &parser, SmallVectorImpl<Attribute> &parsedParams) { 26 // No parameters 27 if (parser.parseOptionalLess() || !parser.parseOptionalGreater()) 28 return success(); 29 30 Attribute attr; 31 if (parser.parseAttribute(attr)) 32 return failure(); 33 parsedParams.push_back(attr); 34 35 while (parser.parseOptionalGreater()) { 36 Attribute attr; 37 if (parser.parseComma() || parser.parseAttribute(attr)) 38 return failure(); 39 parsedParams.push_back(attr); 40 } 41 42 return success(); 43 } 44 45 /// Default printer for dynamic attribute or type parameters. 46 /// Print in the format '(<>)?' or '<attr (,attr)*>'. 47 static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef<Attribute> params) { 48 if (params.empty()) 49 return; 50 51 printer << "<"; 52 interleaveComma(params, printer.getStream()); 53 printer << ">"; 54 } 55 56 //===----------------------------------------------------------------------===// 57 // Dynamic type 58 //===----------------------------------------------------------------------===// 59 60 std::unique_ptr<DynamicTypeDefinition> 61 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect, 62 VerifierFn &&verifier) { 63 return DynamicTypeDefinition::get(name, dialect, std::move(verifier), 64 typeOrAttrParser, typeOrAttrPrinter); 65 } 66 67 std::unique_ptr<DynamicTypeDefinition> 68 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect, 69 VerifierFn &&verifier, ParserFn &&parser, 70 PrinterFn &&printer) { 71 return std::unique_ptr<DynamicTypeDefinition>( 72 new DynamicTypeDefinition(name, dialect, std::move(verifier), 73 std::move(parser), std::move(printer))); 74 } 75 76 DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef, 77 ExtensibleDialect *dialect, 78 VerifierFn &&verifier, 79 ParserFn &&parser, 80 PrinterFn &&printer) 81 : name(nameRef), dialect(dialect), verifier(std::move(verifier)), 82 parser(std::move(parser)), printer(std::move(printer)), 83 ctx(dialect->getContext()) { 84 assert(!nameRef.contains('.') && 85 "name should not be prefixed by the dialect name"); 86 } 87 88 DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect *dialect, 89 StringRef nameRef) 90 : name(nameRef), dialect(dialect), ctx(dialect->getContext()) { 91 assert(!nameRef.contains('.') && 92 "name should not be prefixed by the dialect name"); 93 } 94 95 void DynamicTypeDefinition::registerInTypeUniquer() { 96 detail::TypeUniquer::registerType<DynamicType>(&getContext(), getTypeID()); 97 } 98 99 namespace mlir { 100 namespace detail { 101 /// Storage of DynamicType. 102 /// Contains a pointer to the type definition and type parameters. 103 struct DynamicTypeStorage : public TypeStorage { 104 105 using KeyTy = std::pair<DynamicTypeDefinition *, ArrayRef<Attribute>>; 106 107 explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef, 108 ArrayRef<Attribute> params) 109 : typeDef(typeDef), params(params) {} 110 111 bool operator==(const KeyTy &key) const { 112 return typeDef == key.first && params == key.second; 113 } 114 115 static llvm::hash_code hashKey(const KeyTy &key) { 116 return llvm::hash_value(key); 117 } 118 119 static DynamicTypeStorage *construct(TypeStorageAllocator &alloc, 120 const KeyTy &key) { 121 return new (alloc.allocate<DynamicTypeStorage>()) 122 DynamicTypeStorage(key.first, alloc.copyInto(key.second)); 123 } 124 125 /// Definition of the type. 126 DynamicTypeDefinition *typeDef; 127 128 /// The type parameters. 129 ArrayRef<Attribute> params; 130 }; 131 } // namespace detail 132 } // namespace mlir 133 134 DynamicType DynamicType::get(DynamicTypeDefinition *typeDef, 135 ArrayRef<Attribute> params) { 136 auto &ctx = typeDef->getContext(); 137 auto emitError = detail::getDefaultDiagnosticEmitFn(&ctx); 138 assert(succeeded(typeDef->verify(emitError, params))); 139 return detail::TypeUniquer::getWithTypeID<DynamicType>( 140 &ctx, typeDef->getTypeID(), typeDef, params); 141 } 142 143 DynamicType 144 DynamicType::getChecked(function_ref<InFlightDiagnostic()> emitError, 145 DynamicTypeDefinition *typeDef, 146 ArrayRef<Attribute> params) { 147 if (failed(typeDef->verify(emitError, params))) 148 return {}; 149 auto &ctx = typeDef->getContext(); 150 return detail::TypeUniquer::getWithTypeID<DynamicType>( 151 &ctx, typeDef->getTypeID(), typeDef, params); 152 } 153 154 DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; } 155 156 ArrayRef<Attribute> DynamicType::getParams() { return getImpl()->params; } 157 158 bool DynamicType::classof(Type type) { 159 return type.hasTrait<IsDynamicTypeTrait>(); 160 } 161 162 ParseResult DynamicType::parse(AsmParser &parser, 163 DynamicTypeDefinition *typeDef, 164 DynamicType &parsedType) { 165 SmallVector<Attribute> params; 166 if (failed(typeDef->parser(parser, params))) 167 return failure(); 168 parsedType = parser.getChecked<DynamicType>(typeDef, params); 169 if (!parsedType) 170 return failure(); 171 return success(); 172 } 173 174 void DynamicType::print(AsmPrinter &printer) { 175 printer << getTypeDef()->getName(); 176 getTypeDef()->printer(printer, getParams()); 177 } 178 179 //===----------------------------------------------------------------------===// 180 // Dynamic attribute 181 //===----------------------------------------------------------------------===// 182 183 std::unique_ptr<DynamicAttrDefinition> 184 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect, 185 VerifierFn &&verifier) { 186 return DynamicAttrDefinition::get(name, dialect, std::move(verifier), 187 typeOrAttrParser, typeOrAttrPrinter); 188 } 189 190 std::unique_ptr<DynamicAttrDefinition> 191 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect, 192 VerifierFn &&verifier, ParserFn &&parser, 193 PrinterFn &&printer) { 194 return std::unique_ptr<DynamicAttrDefinition>( 195 new DynamicAttrDefinition(name, dialect, std::move(verifier), 196 std::move(parser), std::move(printer))); 197 } 198 199 DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef, 200 ExtensibleDialect *dialect, 201 VerifierFn &&verifier, 202 ParserFn &&parser, 203 PrinterFn &&printer) 204 : name(nameRef), dialect(dialect), verifier(std::move(verifier)), 205 parser(std::move(parser)), printer(std::move(printer)), 206 ctx(dialect->getContext()) { 207 assert(!nameRef.contains('.') && 208 "name should not be prefixed by the dialect name"); 209 } 210 211 DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect, 212 StringRef nameRef) 213 : name(nameRef), dialect(dialect), ctx(dialect->getContext()) { 214 assert(!nameRef.contains('.') && 215 "name should not be prefixed by the dialect name"); 216 } 217 218 void DynamicAttrDefinition::registerInAttrUniquer() { 219 detail::AttributeUniquer::registerAttribute<DynamicAttr>(&getContext(), 220 getTypeID()); 221 } 222 223 namespace mlir { 224 namespace detail { 225 /// Storage of DynamicAttr. 226 /// Contains a pointer to the attribute definition and attribute parameters. 227 struct DynamicAttrStorage : public AttributeStorage { 228 using KeyTy = std::pair<DynamicAttrDefinition *, ArrayRef<Attribute>>; 229 230 explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef, 231 ArrayRef<Attribute> params) 232 : attrDef(attrDef), params(params) {} 233 234 bool operator==(const KeyTy &key) const { 235 return attrDef == key.first && params == key.second; 236 } 237 238 static llvm::hash_code hashKey(const KeyTy &key) { 239 return llvm::hash_value(key); 240 } 241 242 static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc, 243 const KeyTy &key) { 244 return new (alloc.allocate<DynamicAttrStorage>()) 245 DynamicAttrStorage(key.first, alloc.copyInto(key.second)); 246 } 247 248 /// Definition of the type. 249 DynamicAttrDefinition *attrDef; 250 251 /// The type parameters. 252 ArrayRef<Attribute> params; 253 }; 254 } // namespace detail 255 } // namespace mlir 256 257 DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef, 258 ArrayRef<Attribute> params) { 259 auto &ctx = attrDef->getContext(); 260 return detail::AttributeUniquer::getWithTypeID<DynamicAttr>( 261 &ctx, attrDef->getTypeID(), attrDef, params); 262 } 263 264 DynamicAttr 265 DynamicAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, 266 DynamicAttrDefinition *attrDef, 267 ArrayRef<Attribute> params) { 268 if (failed(attrDef->verify(emitError, params))) 269 return {}; 270 return get(attrDef, params); 271 } 272 273 DynamicAttrDefinition *DynamicAttr::getAttrDef() { return getImpl()->attrDef; } 274 275 ArrayRef<Attribute> DynamicAttr::getParams() { return getImpl()->params; } 276 277 bool DynamicAttr::classof(Attribute attr) { 278 return attr.hasTrait<IsDynamicAttrTrait>(); 279 } 280 281 ParseResult DynamicAttr::parse(AsmParser &parser, 282 DynamicAttrDefinition *attrDef, 283 DynamicAttr &parsedAttr) { 284 SmallVector<Attribute> params; 285 if (failed(attrDef->parser(parser, params))) 286 return failure(); 287 parsedAttr = parser.getChecked<DynamicAttr>(attrDef, params); 288 if (!parsedAttr) 289 return failure(); 290 return success(); 291 } 292 293 void DynamicAttr::print(AsmPrinter &printer) { 294 printer << getAttrDef()->getName(); 295 getAttrDef()->printer(printer, getParams()); 296 } 297 298 //===----------------------------------------------------------------------===// 299 // Dynamic operation 300 //===----------------------------------------------------------------------===// 301 302 DynamicOpDefinition::DynamicOpDefinition( 303 StringRef name, ExtensibleDialect *dialect, 304 OperationName::VerifyInvariantsFn &&verifyFn, 305 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn, 306 OperationName::ParseAssemblyFn &&parseFn, 307 OperationName::PrintAssemblyFn &&printFn, 308 OperationName::FoldHookFn &&foldHookFn, 309 OperationName::GetCanonicalizationPatternsFn 310 &&getCanonicalizationPatternsFn) 311 : typeID(dialect->allocateTypeID()), 312 name((dialect->getNamespace() + "." + name).str()), dialect(dialect), 313 verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)), 314 parseFn(std::move(parseFn)), printFn(std::move(printFn)), 315 foldHookFn(std::move(foldHookFn)), 316 getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)) { 317 assert(!name.contains('.') && 318 "name should not be prefixed by the dialect name"); 319 } 320 321 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get( 322 StringRef name, ExtensibleDialect *dialect, 323 OperationName::VerifyInvariantsFn &&verifyFn, 324 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn) { 325 auto parseFn = [](OpAsmParser &parser, OperationState &result) { 326 return parser.emitError( 327 parser.getCurrentLocation(), 328 "dynamic operation do not define any parser function"); 329 }; 330 331 auto printFn = [](Operation *op, OpAsmPrinter &printer, StringRef) { 332 printer.printGenericOp(op); 333 }; 334 335 return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), 336 std::move(verifyRegionFn), std::move(parseFn), 337 std::move(printFn)); 338 } 339 340 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get( 341 StringRef name, ExtensibleDialect *dialect, 342 OperationName::VerifyInvariantsFn &&verifyFn, 343 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn, 344 OperationName::ParseAssemblyFn &&parseFn, 345 OperationName::PrintAssemblyFn &&printFn) { 346 auto foldHookFn = [](Operation *op, ArrayRef<Attribute> operands, 347 SmallVectorImpl<OpFoldResult> &results) { 348 return failure(); 349 }; 350 351 auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) { 352 }; 353 354 return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), 355 std::move(verifyRegionFn), std::move(parseFn), 356 std::move(printFn), std::move(foldHookFn), 357 std::move(getCanonicalizationPatternsFn)); 358 } 359 360 std::unique_ptr<DynamicOpDefinition> 361 DynamicOpDefinition::get(StringRef name, ExtensibleDialect *dialect, 362 OperationName::VerifyInvariantsFn &&verifyFn, 363 OperationName::VerifyInvariantsFn &&verifyRegionFn, 364 OperationName::ParseAssemblyFn &&parseFn, 365 OperationName::PrintAssemblyFn &&printFn, 366 OperationName::FoldHookFn &&foldHookFn, 367 OperationName::GetCanonicalizationPatternsFn 368 &&getCanonicalizationPatternsFn) { 369 return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition( 370 name, dialect, std::move(verifyFn), std::move(verifyRegionFn), 371 std::move(parseFn), std::move(printFn), std::move(foldHookFn), 372 std::move(getCanonicalizationPatternsFn))); 373 } 374 375 //===----------------------------------------------------------------------===// 376 // Extensible dialect 377 //===----------------------------------------------------------------------===// 378 379 namespace { 380 /// Interface that can only be implemented by extensible dialects. 381 /// The interface is used to check if a dialect is extensible or not. 382 class IsExtensibleDialect : public DialectInterface::Base<IsExtensibleDialect> { 383 public: 384 IsExtensibleDialect(Dialect *dialect) : Base(dialect) {} 385 }; 386 } // namespace 387 388 ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx, 389 TypeID typeID) 390 : Dialect(name, ctx, typeID) { 391 addInterfaces<IsExtensibleDialect>(); 392 } 393 394 void ExtensibleDialect::registerDynamicType( 395 std::unique_ptr<DynamicTypeDefinition> &&type) { 396 DynamicTypeDefinition *typePtr = type.get(); 397 TypeID typeID = type->getTypeID(); 398 StringRef name = type->getName(); 399 ExtensibleDialect *dialect = type->getDialect(); 400 401 assert(dialect == this && 402 "trying to register a dynamic type in the wrong dialect"); 403 404 // If a type with the same name is already defined, fail. 405 auto registered = dynTypes.try_emplace(typeID, std::move(type)).second; 406 (void)registered; 407 assert(registered && "type TypeID was not unique"); 408 409 registered = nameToDynTypes.insert({name, typePtr}).second; 410 (void)registered; 411 assert(registered && 412 "Trying to create a new dynamic type with an existing name"); 413 414 auto abstractType = 415 AbstractType::get(*dialect, DynamicAttr::getInterfaceMap(), 416 DynamicType::getHasTraitFn(), typeID); 417 418 /// Add the type to the dialect and the type uniquer. 419 addType(typeID, std::move(abstractType)); 420 typePtr->registerInTypeUniquer(); 421 } 422 423 void ExtensibleDialect::registerDynamicAttr( 424 std::unique_ptr<DynamicAttrDefinition> &&attr) { 425 auto *attrPtr = attr.get(); 426 auto typeID = attr->getTypeID(); 427 auto name = attr->getName(); 428 auto *dialect = attr->getDialect(); 429 430 assert(dialect == this && 431 "trying to register a dynamic attribute in the wrong dialect"); 432 433 // If an attribute with the same name is already defined, fail. 434 auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second; 435 (void)registered; 436 assert(registered && "attribute TypeID was not unique"); 437 438 registered = nameToDynAttrs.insert({name, attrPtr}).second; 439 (void)registered; 440 assert(registered && 441 "Trying to create a new dynamic attribute with an existing name"); 442 443 auto abstractAttr = 444 AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(), 445 DynamicAttr::getHasTraitFn(), typeID); 446 447 /// Add the type to the dialect and the type uniquer. 448 addAttribute(typeID, std::move(abstractAttr)); 449 attrPtr->registerInAttrUniquer(); 450 } 451 452 void ExtensibleDialect::registerDynamicOp( 453 std::unique_ptr<DynamicOpDefinition> &&op) { 454 assert(op->dialect == this && 455 "trying to register a dynamic op in the wrong dialect"); 456 auto hasTraitFn = [](TypeID traitId) { return false; }; 457 458 RegisteredOperationName::insert( 459 op->name, *op->dialect, op->typeID, std::move(op->parseFn), 460 std::move(op->printFn), std::move(op->verifyFn), 461 std::move(op->verifyRegionFn), std::move(op->foldHookFn), 462 std::move(op->getCanonicalizationPatternsFn), 463 detail::InterfaceMap::get<>(), std::move(hasTraitFn), {}); 464 } 465 466 bool ExtensibleDialect::classof(const Dialect *dialect) { 467 return const_cast<Dialect *>(dialect) 468 ->getRegisteredInterface<IsExtensibleDialect>(); 469 } 470 471 OptionalParseResult ExtensibleDialect::parseOptionalDynamicType( 472 StringRef typeName, AsmParser &parser, Type &resultType) const { 473 DynamicTypeDefinition *typeDef = lookupTypeDefinition(typeName); 474 if (!typeDef) 475 return llvm::None; 476 477 DynamicType dynType; 478 if (DynamicType::parse(parser, typeDef, dynType)) 479 return failure(); 480 resultType = dynType; 481 return success(); 482 } 483 484 LogicalResult ExtensibleDialect::printIfDynamicType(Type type, 485 AsmPrinter &printer) { 486 if (auto dynType = type.dyn_cast<DynamicType>()) { 487 dynType.print(printer); 488 return success(); 489 } 490 return failure(); 491 } 492 493 OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr( 494 StringRef attrName, AsmParser &parser, Attribute &resultAttr) const { 495 DynamicAttrDefinition *attrDef = lookupAttrDefinition(attrName); 496 if (!attrDef) 497 return llvm::None; 498 499 DynamicAttr dynAttr; 500 if (DynamicAttr::parse(parser, attrDef, dynAttr)) 501 return failure(); 502 resultAttr = dynAttr; 503 return success(); 504 } 505 506 LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute, 507 AsmPrinter &printer) { 508 if (auto dynAttr = attribute.dyn_cast<DynamicAttr>()) { 509 dynAttr.print(printer); 510 return success(); 511 } 512 return failure(); 513 } 514