1 //===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===// 2 // 3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/ExtensibleDialect.h" 10 #include "mlir/IR/AttributeSupport.h" 11 #include "mlir/IR/DialectImplementation.h" 12 #include "mlir/IR/OperationSupport.h" 13 #include "mlir/IR/StorageUniquerSupport.h" 14 #include "mlir/Support/LogicalResult.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // Dynamic types and attributes shared functions 20 //===----------------------------------------------------------------------===// 21 22 /// Default parser for dynamic attribute or type parameters. 23 /// Parse in the format '(<>)?' or '<attr (,attr)*>'. 24 static LogicalResult 25 typeOrAttrParser(AsmParser &parser, SmallVectorImpl<Attribute> &parsedParams) { 26 // No parameters 27 if (parser.parseOptionalLess() || !parser.parseOptionalGreater()) 28 return success(); 29 30 Attribute attr; 31 if (parser.parseAttribute(attr)) 32 return failure(); 33 parsedParams.push_back(attr); 34 35 while (parser.parseOptionalGreater()) { 36 Attribute attr; 37 if (parser.parseComma() || parser.parseAttribute(attr)) 38 return failure(); 39 parsedParams.push_back(attr); 40 } 41 42 return success(); 43 } 44 45 /// Default printer for dynamic attribute or type parameters. 46 /// Print in the format '(<>)?' or '<attr (,attr)*>'. 47 static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef<Attribute> params) { 48 if (params.empty()) 49 return; 50 51 printer << "<"; 52 interleaveComma(params, printer.getStream()); 53 printer << ">"; 54 } 55 56 //===----------------------------------------------------------------------===// 57 // Dynamic type 58 //===----------------------------------------------------------------------===// 59 60 std::unique_ptr<DynamicTypeDefinition> 61 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect, 62 VerifierFn &&verifier) { 63 return DynamicTypeDefinition::get(name, dialect, std::move(verifier), 64 typeOrAttrParser, typeOrAttrPrinter); 65 } 66 67 std::unique_ptr<DynamicTypeDefinition> 68 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect, 69 VerifierFn &&verifier, ParserFn &&parser, 70 PrinterFn &&printer) { 71 return std::unique_ptr<DynamicTypeDefinition>( 72 new DynamicTypeDefinition(name, dialect, std::move(verifier), 73 std::move(parser), std::move(printer))); 74 } 75 76 DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef, 77 ExtensibleDialect *dialect, 78 VerifierFn &&verifier, 79 ParserFn &&parser, 80 PrinterFn &&printer) 81 : name(nameRef), dialect(dialect), verifier(std::move(verifier)), 82 parser(std::move(parser)), printer(std::move(printer)), 83 ctx(dialect->getContext()) {} 84 85 DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect *dialect, 86 StringRef nameRef) 87 : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {} 88 89 void DynamicTypeDefinition::registerInTypeUniquer() { 90 detail::TypeUniquer::registerType<DynamicType>(&getContext(), getTypeID()); 91 } 92 93 namespace mlir { 94 namespace detail { 95 /// Storage of DynamicType. 96 /// Contains a pointer to the type definition and type parameters. 97 struct DynamicTypeStorage : public TypeStorage { 98 99 using KeyTy = std::pair<DynamicTypeDefinition *, ArrayRef<Attribute>>; 100 101 explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef, 102 ArrayRef<Attribute> params) 103 : typeDef(typeDef), params(params) {} 104 105 bool operator==(const KeyTy &key) const { 106 return typeDef == key.first && params == key.second; 107 } 108 109 static llvm::hash_code hashKey(const KeyTy &key) { 110 return llvm::hash_value(key); 111 } 112 113 static DynamicTypeStorage *construct(TypeStorageAllocator &alloc, 114 const KeyTy &key) { 115 return new (alloc.allocate<DynamicTypeStorage>()) 116 DynamicTypeStorage(key.first, alloc.copyInto(key.second)); 117 } 118 119 /// Definition of the type. 120 DynamicTypeDefinition *typeDef; 121 122 /// The type parameters. 123 ArrayRef<Attribute> params; 124 }; 125 } // namespace detail 126 } // namespace mlir 127 128 DynamicType DynamicType::get(DynamicTypeDefinition *typeDef, 129 ArrayRef<Attribute> params) { 130 auto &ctx = typeDef->getContext(); 131 auto emitError = detail::getDefaultDiagnosticEmitFn(&ctx); 132 assert(succeeded(typeDef->verify(emitError, params))); 133 return detail::TypeUniquer::getWithTypeID<DynamicType>( 134 &ctx, typeDef->getTypeID(), typeDef, params); 135 } 136 137 DynamicType 138 DynamicType::getChecked(function_ref<InFlightDiagnostic()> emitError, 139 DynamicTypeDefinition *typeDef, 140 ArrayRef<Attribute> params) { 141 if (failed(typeDef->verify(emitError, params))) 142 return {}; 143 auto &ctx = typeDef->getContext(); 144 return detail::TypeUniquer::getWithTypeID<DynamicType>( 145 &ctx, typeDef->getTypeID(), typeDef, params); 146 } 147 148 DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; } 149 150 ArrayRef<Attribute> DynamicType::getParams() { return getImpl()->params; } 151 152 bool DynamicType::classof(Type type) { 153 return type.hasTrait<TypeTrait::IsDynamicType>(); 154 } 155 156 ParseResult DynamicType::parse(AsmParser &parser, 157 DynamicTypeDefinition *typeDef, 158 DynamicType &parsedType) { 159 SmallVector<Attribute> params; 160 if (failed(typeDef->parser(parser, params))) 161 return failure(); 162 parsedType = parser.getChecked<DynamicType>(typeDef, params); 163 if (!parsedType) 164 return failure(); 165 return success(); 166 } 167 168 void DynamicType::print(AsmPrinter &printer) { 169 printer << getTypeDef()->getName(); 170 getTypeDef()->printer(printer, getParams()); 171 } 172 173 //===----------------------------------------------------------------------===// 174 // Dynamic attribute 175 //===----------------------------------------------------------------------===// 176 177 std::unique_ptr<DynamicAttrDefinition> 178 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect, 179 VerifierFn &&verifier) { 180 return DynamicAttrDefinition::get(name, dialect, std::move(verifier), 181 typeOrAttrParser, typeOrAttrPrinter); 182 } 183 184 std::unique_ptr<DynamicAttrDefinition> 185 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect, 186 VerifierFn &&verifier, ParserFn &&parser, 187 PrinterFn &&printer) { 188 return std::unique_ptr<DynamicAttrDefinition>( 189 new DynamicAttrDefinition(name, dialect, std::move(verifier), 190 std::move(parser), std::move(printer))); 191 } 192 193 DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef, 194 ExtensibleDialect *dialect, 195 VerifierFn &&verifier, 196 ParserFn &&parser, 197 PrinterFn &&printer) 198 : name(nameRef), dialect(dialect), verifier(std::move(verifier)), 199 parser(std::move(parser)), printer(std::move(printer)), 200 ctx(dialect->getContext()) {} 201 202 DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect, 203 StringRef nameRef) 204 : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {} 205 206 void DynamicAttrDefinition::registerInAttrUniquer() { 207 detail::AttributeUniquer::registerAttribute<DynamicAttr>(&getContext(), 208 getTypeID()); 209 } 210 211 namespace mlir { 212 namespace detail { 213 /// Storage of DynamicAttr. 214 /// Contains a pointer to the attribute definition and attribute parameters. 215 struct DynamicAttrStorage : public AttributeStorage { 216 using KeyTy = std::pair<DynamicAttrDefinition *, ArrayRef<Attribute>>; 217 218 explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef, 219 ArrayRef<Attribute> params) 220 : attrDef(attrDef), params(params) {} 221 222 bool operator==(const KeyTy &key) const { 223 return attrDef == key.first && params == key.second; 224 } 225 226 static llvm::hash_code hashKey(const KeyTy &key) { 227 return llvm::hash_value(key); 228 } 229 230 static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc, 231 const KeyTy &key) { 232 return new (alloc.allocate<DynamicAttrStorage>()) 233 DynamicAttrStorage(key.first, alloc.copyInto(key.second)); 234 } 235 236 /// Definition of the type. 237 DynamicAttrDefinition *attrDef; 238 239 /// The type parameters. 240 ArrayRef<Attribute> params; 241 }; 242 } // namespace detail 243 } // namespace mlir 244 245 DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef, 246 ArrayRef<Attribute> params) { 247 auto &ctx = attrDef->getContext(); 248 return detail::AttributeUniquer::getWithTypeID<DynamicAttr>( 249 &ctx, attrDef->getTypeID(), attrDef, params); 250 } 251 252 DynamicAttr 253 DynamicAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, 254 DynamicAttrDefinition *attrDef, 255 ArrayRef<Attribute> params) { 256 if (failed(attrDef->verify(emitError, params))) 257 return {}; 258 return get(attrDef, params); 259 } 260 261 DynamicAttrDefinition *DynamicAttr::getAttrDef() { return getImpl()->attrDef; } 262 263 ArrayRef<Attribute> DynamicAttr::getParams() { return getImpl()->params; } 264 265 bool DynamicAttr::classof(Attribute attr) { 266 return attr.hasTrait<AttributeTrait::IsDynamicAttr>(); 267 } 268 269 ParseResult DynamicAttr::parse(AsmParser &parser, 270 DynamicAttrDefinition *attrDef, 271 DynamicAttr &parsedAttr) { 272 SmallVector<Attribute> params; 273 if (failed(attrDef->parser(parser, params))) 274 return failure(); 275 parsedAttr = parser.getChecked<DynamicAttr>(attrDef, params); 276 if (!parsedAttr) 277 return failure(); 278 return success(); 279 } 280 281 void DynamicAttr::print(AsmPrinter &printer) { 282 printer << getAttrDef()->getName(); 283 getAttrDef()->printer(printer, getParams()); 284 } 285 286 //===----------------------------------------------------------------------===// 287 // Dynamic operation 288 //===----------------------------------------------------------------------===// 289 290 DynamicOpDefinition::DynamicOpDefinition( 291 StringRef name, ExtensibleDialect *dialect, 292 OperationName::VerifyInvariantsFn &&verifyFn, 293 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn, 294 OperationName::ParseAssemblyFn &&parseFn, 295 OperationName::PrintAssemblyFn &&printFn, 296 OperationName::FoldHookFn &&foldHookFn, 297 OperationName::GetCanonicalizationPatternsFn 298 &&getCanonicalizationPatternsFn) 299 : typeID(dialect->allocateTypeID()), 300 name((dialect->getNamespace() + "." + name).str()), dialect(dialect), 301 verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)), 302 parseFn(std::move(parseFn)), printFn(std::move(printFn)), 303 foldHookFn(std::move(foldHookFn)), 304 getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)) {} 305 306 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get( 307 StringRef name, ExtensibleDialect *dialect, 308 OperationName::VerifyInvariantsFn &&verifyFn, 309 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn) { 310 auto parseFn = [](OpAsmParser &parser, OperationState &result) { 311 return parser.emitError( 312 parser.getCurrentLocation(), 313 "dynamic operation do not define any parser function"); 314 }; 315 316 auto printFn = [](Operation *op, OpAsmPrinter &printer, StringRef) { 317 printer.printGenericOp(op); 318 }; 319 320 return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), 321 std::move(verifyRegionFn), std::move(parseFn), 322 std::move(printFn)); 323 } 324 325 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get( 326 StringRef name, ExtensibleDialect *dialect, 327 OperationName::VerifyInvariantsFn &&verifyFn, 328 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn, 329 OperationName::ParseAssemblyFn &&parseFn, 330 OperationName::PrintAssemblyFn &&printFn) { 331 auto foldHookFn = [](Operation *op, ArrayRef<Attribute> operands, 332 SmallVectorImpl<OpFoldResult> &results) { 333 return failure(); 334 }; 335 336 auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) { 337 }; 338 339 return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), 340 std::move(verifyRegionFn), std::move(parseFn), 341 std::move(printFn), std::move(foldHookFn), 342 std::move(getCanonicalizationPatternsFn)); 343 } 344 345 std::unique_ptr<DynamicOpDefinition> 346 DynamicOpDefinition::get(StringRef name, ExtensibleDialect *dialect, 347 OperationName::VerifyInvariantsFn &&verifyFn, 348 OperationName::VerifyInvariantsFn &&verifyRegionFn, 349 OperationName::ParseAssemblyFn &&parseFn, 350 OperationName::PrintAssemblyFn &&printFn, 351 OperationName::FoldHookFn &&foldHookFn, 352 OperationName::GetCanonicalizationPatternsFn 353 &&getCanonicalizationPatternsFn) { 354 return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition( 355 name, dialect, std::move(verifyFn), std::move(verifyRegionFn), 356 std::move(parseFn), std::move(printFn), std::move(foldHookFn), 357 std::move(getCanonicalizationPatternsFn))); 358 } 359 360 //===----------------------------------------------------------------------===// 361 // Extensible dialect 362 //===----------------------------------------------------------------------===// 363 364 namespace { 365 /// Interface that can only be implemented by extensible dialects. 366 /// The interface is used to check if a dialect is extensible or not. 367 class IsExtensibleDialect : public DialectInterface::Base<IsExtensibleDialect> { 368 public: 369 IsExtensibleDialect(Dialect *dialect) : Base(dialect) {} 370 371 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsExtensibleDialect) 372 }; 373 } // namespace 374 375 ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx, 376 TypeID typeID) 377 : Dialect(name, ctx, typeID) { 378 addInterfaces<IsExtensibleDialect>(); 379 } 380 381 void ExtensibleDialect::registerDynamicType( 382 std::unique_ptr<DynamicTypeDefinition> &&type) { 383 DynamicTypeDefinition *typePtr = type.get(); 384 TypeID typeID = type->getTypeID(); 385 StringRef name = type->getName(); 386 ExtensibleDialect *dialect = type->getDialect(); 387 388 assert(dialect == this && 389 "trying to register a dynamic type in the wrong dialect"); 390 391 // If a type with the same name is already defined, fail. 392 auto registered = dynTypes.try_emplace(typeID, std::move(type)).second; 393 (void)registered; 394 assert(registered && "type TypeID was not unique"); 395 396 registered = nameToDynTypes.insert({name, typePtr}).second; 397 (void)registered; 398 assert(registered && 399 "Trying to create a new dynamic type with an existing name"); 400 401 auto abstractType = 402 AbstractType::get(*dialect, DynamicAttr::getInterfaceMap(), 403 DynamicType::getHasTraitFn(), typeID); 404 405 /// Add the type to the dialect and the type uniquer. 406 addType(typeID, std::move(abstractType)); 407 typePtr->registerInTypeUniquer(); 408 } 409 410 void ExtensibleDialect::registerDynamicAttr( 411 std::unique_ptr<DynamicAttrDefinition> &&attr) { 412 auto *attrPtr = attr.get(); 413 auto typeID = attr->getTypeID(); 414 auto name = attr->getName(); 415 auto *dialect = attr->getDialect(); 416 417 assert(dialect == this && 418 "trying to register a dynamic attribute in the wrong dialect"); 419 420 // If an attribute with the same name is already defined, fail. 421 auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second; 422 (void)registered; 423 assert(registered && "attribute TypeID was not unique"); 424 425 registered = nameToDynAttrs.insert({name, attrPtr}).second; 426 (void)registered; 427 assert(registered && 428 "Trying to create a new dynamic attribute with an existing name"); 429 430 auto abstractAttr = 431 AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(), 432 DynamicAttr::getHasTraitFn(), typeID); 433 434 /// Add the type to the dialect and the type uniquer. 435 addAttribute(typeID, std::move(abstractAttr)); 436 attrPtr->registerInAttrUniquer(); 437 } 438 439 void ExtensibleDialect::registerDynamicOp( 440 std::unique_ptr<DynamicOpDefinition> &&op) { 441 assert(op->dialect == this && 442 "trying to register a dynamic op in the wrong dialect"); 443 auto hasTraitFn = [](TypeID traitId) { return false; }; 444 445 RegisteredOperationName::insert( 446 op->name, *op->dialect, op->typeID, std::move(op->parseFn), 447 std::move(op->printFn), std::move(op->verifyFn), 448 std::move(op->verifyRegionFn), std::move(op->foldHookFn), 449 std::move(op->getCanonicalizationPatternsFn), 450 detail::InterfaceMap::get<>(), std::move(hasTraitFn), {}); 451 } 452 453 bool ExtensibleDialect::classof(const Dialect *dialect) { 454 return const_cast<Dialect *>(dialect) 455 ->getRegisteredInterface<IsExtensibleDialect>(); 456 } 457 458 OptionalParseResult ExtensibleDialect::parseOptionalDynamicType( 459 StringRef typeName, AsmParser &parser, Type &resultType) const { 460 DynamicTypeDefinition *typeDef = lookupTypeDefinition(typeName); 461 if (!typeDef) 462 return llvm::None; 463 464 DynamicType dynType; 465 if (DynamicType::parse(parser, typeDef, dynType)) 466 return failure(); 467 resultType = dynType; 468 return success(); 469 } 470 471 LogicalResult ExtensibleDialect::printIfDynamicType(Type type, 472 AsmPrinter &printer) { 473 if (auto dynType = type.dyn_cast<DynamicType>()) { 474 dynType.print(printer); 475 return success(); 476 } 477 return failure(); 478 } 479 480 OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr( 481 StringRef attrName, AsmParser &parser, Attribute &resultAttr) const { 482 DynamicAttrDefinition *attrDef = lookupAttrDefinition(attrName); 483 if (!attrDef) 484 return llvm::None; 485 486 DynamicAttr dynAttr; 487 if (DynamicAttr::parse(parser, attrDef, dynAttr)) 488 return failure(); 489 resultAttr = dynAttr; 490 return success(); 491 } 492 493 LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute, 494 AsmPrinter &printer) { 495 if (auto dynAttr = attribute.dyn_cast<DynamicAttr>()) { 496 dynAttr.print(printer); 497 return success(); 498 } 499 return failure(); 500 } 501