1 //===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/SymbolTable.h" 10 #include "llvm/ADT/SetVector.h" 11 #include "llvm/ADT/SmallPtrSet.h" 12 #include "llvm/ADT/SmallString.h" 13 #include "llvm/ADT/StringSwitch.h" 14 15 using namespace mlir; 16 17 /// Return true if the given operation is unknown and may potentially define a 18 /// symbol table. 19 static bool isPotentiallyUnknownSymbolTable(Operation *op) { 20 return !op->getDialect() && op->getNumRegions() == 1; 21 } 22 23 /// Returns the nearest symbol table from a given operation `from`. Returns 24 /// nullptr if no valid parent symbol table could be found. 25 static Operation *getNearestSymbolTable(Operation *from) { 26 assert(from && "expected valid operation"); 27 if (isPotentiallyUnknownSymbolTable(from)) 28 return nullptr; 29 30 while (!from->hasTrait<OpTrait::SymbolTable>()) { 31 from = from->getParentOp(); 32 33 // Check that this is a valid op and isn't an unknown symbol table. 34 if (!from || isPotentiallyUnknownSymbolTable(from)) 35 return nullptr; 36 } 37 return from; 38 } 39 40 /// Returns the string name of the given symbol, or None if this is not a 41 /// symbol. 42 static Optional<StringRef> getNameIfSymbol(Operation *symbol) { 43 auto nameAttr = 44 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); 45 return nameAttr ? nameAttr.getValue() : Optional<StringRef>(); 46 } 47 48 /// Computes the nested symbol reference attribute for the symbol 'symbolName' 49 /// that are usable within the symbol table operations from 'symbol' as far up 50 /// to the given operation 'within', where 'within' is an ancestor of 'symbol'. 51 /// Returns success if all references up to 'within' could be computed. 52 static LogicalResult 53 collectValidReferencesFor(Operation *symbol, StringRef symbolName, 54 Operation *within, 55 SmallVectorImpl<SymbolRefAttr> &results) { 56 assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor"); 57 MLIRContext *ctx = symbol->getContext(); 58 59 auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx); 60 results.push_back(leafRef); 61 62 // Early exit for when 'within' is the parent of 'symbol'. 63 Operation *symbolTableOp = symbol->getParentOp(); 64 if (within == symbolTableOp) 65 return success(); 66 67 // Collect references until 'symbolTableOp' reaches 'within'. 68 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef); 69 do { 70 // Each parent of 'symbol' should define a symbol table. 71 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) 72 return failure(); 73 // Each parent of 'symbol' should also be a symbol. 74 Optional<StringRef> symbolTableName = getNameIfSymbol(symbolTableOp); 75 if (!symbolTableName) 76 return failure(); 77 results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx)); 78 79 symbolTableOp = symbolTableOp->getParentOp(); 80 if (symbolTableOp == within) 81 break; 82 nestedRefs.insert(nestedRefs.begin(), 83 FlatSymbolRefAttr::get(*symbolTableName, ctx)); 84 } while (true); 85 return success(); 86 } 87 88 //===----------------------------------------------------------------------===// 89 // SymbolTable 90 //===----------------------------------------------------------------------===// 91 92 /// Build a symbol table with the symbols within the given operation. 93 SymbolTable::SymbolTable(Operation *symbolTableOp) 94 : symbolTableOp(symbolTableOp) { 95 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() && 96 "expected operation to have SymbolTable trait"); 97 assert(symbolTableOp->getNumRegions() == 1 && 98 "expected operation to have a single region"); 99 assert(has_single_element(symbolTableOp->getRegion(0)) && 100 "expected operation to have a single block"); 101 102 for (auto &op : symbolTableOp->getRegion(0).front()) { 103 Optional<StringRef> name = getNameIfSymbol(&op); 104 if (!name) 105 continue; 106 107 auto inserted = symbolTable.insert({*name, &op}); 108 (void)inserted; 109 assert(inserted.second && 110 "expected region to contain uniquely named symbol operations"); 111 } 112 } 113 114 /// Look up a symbol with the specified name, returning null if no such name 115 /// exists. Names never include the @ on them. 116 Operation *SymbolTable::lookup(StringRef name) const { 117 return symbolTable.lookup(name); 118 } 119 120 /// Erase the given symbol from the table. 121 void SymbolTable::erase(Operation *symbol) { 122 Optional<StringRef> name = getNameIfSymbol(symbol); 123 assert(name && "expected valid 'name' attribute"); 124 assert(symbol->getParentOp() == symbolTableOp && 125 "expected this operation to be inside of the operation with this " 126 "SymbolTable"); 127 128 auto it = symbolTable.find(*name); 129 if (it != symbolTable.end() && it->second == symbol) { 130 symbolTable.erase(it); 131 symbol->erase(); 132 } 133 } 134 135 /// Insert a new symbol into the table and associated operation, and rename it 136 /// as necessary to avoid collisions. 137 void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { 138 auto &body = symbolTableOp->getRegion(0).front(); 139 if (insertPt == Block::iterator() || insertPt == body.end()) 140 insertPt = Block::iterator(body.getTerminator()); 141 142 assert(insertPt->getParentOp() == symbolTableOp && 143 "expected insertPt to be in the associated module operation"); 144 145 body.getOperations().insert(insertPt, symbol); 146 147 // Add this symbol to the symbol table, uniquing the name if a conflict is 148 // detected. 149 StringRef name = getSymbolName(symbol); 150 if (symbolTable.insert({name, symbol}).second) 151 return; 152 // If a conflict was detected, then the symbol will not have been added to 153 // the symbol table. Try suffixes until we get to a unique name that works. 154 SmallString<128> nameBuffer(name); 155 unsigned originalLength = nameBuffer.size(); 156 157 // Iteratively try suffixes until we find one that isn't used. 158 do { 159 nameBuffer.resize(originalLength); 160 nameBuffer += '_'; 161 nameBuffer += std::to_string(uniquingCounter++); 162 } while (!symbolTable.insert({nameBuffer, symbol}).second); 163 setSymbolName(symbol, nameBuffer); 164 } 165 166 /// Returns true if the given operation defines a symbol. 167 bool SymbolTable::isSymbol(Operation *op) { 168 return op->hasTrait<OpTrait::Symbol>() || getNameIfSymbol(op).hasValue(); 169 } 170 171 /// Returns the name of the given symbol operation. 172 StringRef SymbolTable::getSymbolName(Operation *symbol) { 173 Optional<StringRef> name = getNameIfSymbol(symbol); 174 assert(name && "expected valid symbol name"); 175 return *name; 176 } 177 /// Sets the name of the given symbol operation. 178 void SymbolTable::setSymbolName(Operation *symbol, StringRef name) { 179 symbol->setAttr(getSymbolAttrName(), 180 StringAttr::get(name, symbol->getContext())); 181 } 182 183 /// Returns the visibility of the given symbol operation. 184 SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) { 185 // If the attribute doesn't exist, assume public. 186 StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName()); 187 if (!vis) 188 return Visibility::Public; 189 190 // Otherwise, switch on the string value. 191 return llvm::StringSwitch<Visibility>(vis.getValue()) 192 .Case("private", Visibility::Private) 193 .Case("nested", Visibility::Nested) 194 .Case("public", Visibility::Public); 195 } 196 /// Sets the visibility of the given symbol operation. 197 void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) { 198 MLIRContext *ctx = symbol->getContext(); 199 200 // If the visibility is public, just drop the attribute as this is the 201 // default. 202 if (vis == Visibility::Public) { 203 symbol->removeAttr(Identifier::get(getVisibilityAttrName(), ctx)); 204 return; 205 } 206 207 // Otherwise, update the attribute. 208 assert((vis == Visibility::Private || vis == Visibility::Nested) && 209 "unknown symbol visibility kind"); 210 211 StringRef visName = vis == Visibility::Private ? "private" : "nested"; 212 symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx)); 213 } 214 215 /// Returns the operation registered with the given symbol name with the 216 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation 217 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol 218 /// was found. 219 Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, 220 StringRef symbol) { 221 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); 222 223 // Look for a symbol with the given name. 224 for (auto &block : symbolTableOp->getRegion(0)) { 225 for (auto &op : block) 226 if (getNameIfSymbol(&op) == symbol) 227 return &op; 228 } 229 return nullptr; 230 } 231 Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, 232 SymbolRefAttr symbol) { 233 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); 234 235 // Lookup the root reference for this symbol. 236 symbolTableOp = lookupSymbolIn(symbolTableOp, symbol.getRootReference()); 237 if (!symbolTableOp) 238 return nullptr; 239 240 // If there are no nested references, just return the root symbol directly. 241 ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences(); 242 if (nestedRefs.empty()) 243 return symbolTableOp; 244 245 // Verify that the root is also a symbol table. 246 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) 247 return nullptr; 248 249 // Otherwise, lookup each of the nested non-leaf references and ensure that 250 // each corresponds to a valid symbol table. 251 for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) { 252 symbolTableOp = lookupSymbolIn(symbolTableOp, ref.getValue()); 253 if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>()) 254 return nullptr; 255 } 256 return lookupSymbolIn(symbolTableOp, symbol.getLeafReference()); 257 } 258 259 /// Returns the operation registered with the given symbol name within the 260 /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns 261 /// nullptr if no valid symbol was found. 262 Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from, 263 StringRef symbol) { 264 Operation *symbolTableOp = getNearestSymbolTable(from); 265 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; 266 } 267 Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from, 268 SymbolRefAttr symbol) { 269 Operation *symbolTableOp = getNearestSymbolTable(from); 270 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; 271 } 272 273 //===----------------------------------------------------------------------===// 274 // SymbolTable Trait Types 275 //===----------------------------------------------------------------------===// 276 277 LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) { 278 if (op->getNumRegions() != 1) 279 return op->emitOpError() 280 << "Operations with a 'SymbolTable' must have exactly one region"; 281 if (!has_single_element(op->getRegion(0))) 282 return op->emitOpError() 283 << "Operations with a 'SymbolTable' must have exactly one block"; 284 285 // Check that all symbols are uniquely named within child regions. 286 DenseMap<Attribute, Location> nameToOrigLoc; 287 for (auto &block : op->getRegion(0)) { 288 for (auto &op : block) { 289 // Check for a symbol name attribute. 290 auto nameAttr = 291 op.getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()); 292 if (!nameAttr) 293 continue; 294 295 // Try to insert this symbol into the table. 296 auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc()); 297 if (!it.second) 298 return op.emitError() 299 .append("redefinition of symbol named '", nameAttr.getValue(), "'") 300 .attachNote(it.first->second) 301 .append("see existing symbol definition here"); 302 } 303 } 304 return success(); 305 } 306 307 LogicalResult OpTrait::impl::verifySymbol(Operation *op) { 308 // Verify the name attribute. 309 if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName())) 310 return op->emitOpError() << "requires string attribute '" 311 << mlir::SymbolTable::getSymbolAttrName() << "'"; 312 313 // Verify the visibility attribute. 314 if (Attribute vis = op->getAttr(mlir::SymbolTable::getVisibilityAttrName())) { 315 StringAttr visStrAttr = vis.dyn_cast<StringAttr>(); 316 if (!visStrAttr) 317 return op->emitOpError() << "requires visibility attribute '" 318 << mlir::SymbolTable::getVisibilityAttrName() 319 << "' to be a string attribute, but got " << vis; 320 321 if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"}, 322 visStrAttr.getValue())) 323 return op->emitOpError() 324 << "visibility expected to be one of [\"public\", \"private\", " 325 "\"nested\"], but got " 326 << visStrAttr; 327 } 328 return success(); 329 } 330 331 //===----------------------------------------------------------------------===// 332 // Symbol Use Lists 333 //===----------------------------------------------------------------------===// 334 335 /// Walk all of the symbol references within the given operation, invoking the 336 /// provided callback for each found use. The callbacks takes as arguments: the 337 /// use of the symbol, and the nested access chain to the attribute within the 338 /// operation dictionary. An access chain is a set of indices into nested 339 /// container attributes. For example, a symbol use in an attribute dictionary 340 /// that looks like the following: 341 /// 342 /// {use = [{other_attr, @symbol}]} 343 /// 344 /// May have the following access chain: 345 /// 346 /// [0, 0, 1] 347 /// 348 static WalkResult walkSymbolRefs( 349 Operation *op, 350 function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) { 351 // Check to see if the operation has any attributes. 352 DictionaryAttr attrDict = op->getAttrList().getDictionary(); 353 if (!attrDict) 354 return WalkResult::advance(); 355 356 // A worklist of a container attribute and the current index into the held 357 // attribute list. 358 SmallVector<Attribute, 1> attrWorklist(1, attrDict); 359 SmallVector<int, 1> curAccessChain(1, /*Value=*/-1); 360 361 // Process the symbol references within the given nested attribute range. 362 auto processAttrs = [&](int &index, auto attrRange) -> WalkResult { 363 for (Attribute attr : llvm::drop_begin(attrRange, index)) { 364 /// Check for a nested container attribute, these will also need to be 365 /// walked. 366 if (attr.isa<ArrayAttr>() || attr.isa<DictionaryAttr>()) { 367 attrWorklist.push_back(attr); 368 curAccessChain.push_back(-1); 369 return WalkResult::advance(); 370 } 371 372 // Invoke the provided callback if we find a symbol use and check for a 373 // requested interrupt. 374 if (auto symbolRef = attr.dyn_cast<SymbolRefAttr>()) 375 if (callback({op, symbolRef}, curAccessChain).wasInterrupted()) 376 return WalkResult::interrupt(); 377 378 // Make sure to keep the index counter in sync. 379 ++index; 380 } 381 382 // Pop this container attribute from the worklist. 383 attrWorklist.pop_back(); 384 curAccessChain.pop_back(); 385 return WalkResult::advance(); 386 }; 387 388 WalkResult result = WalkResult::advance(); 389 do { 390 Attribute attr = attrWorklist.back(); 391 int &index = curAccessChain.back(); 392 ++index; 393 394 // Process the given attribute, which is guaranteed to be a container. 395 if (auto dict = attr.dyn_cast<DictionaryAttr>()) 396 result = processAttrs(index, make_second_range(dict.getValue())); 397 else 398 result = processAttrs(index, attr.cast<ArrayAttr>().getValue()); 399 } while (!attrWorklist.empty() && !result.wasInterrupted()); 400 return result; 401 } 402 403 /// Walk all of the uses, for any symbol, that are nested within the given 404 /// operation 'from', invoking the provided callback for each. This does not 405 /// traverse into any nested symbol tables, and will also only return uses on 406 /// 'from' if it does not also define a symbol table. 407 static Optional<WalkResult> walkSymbolUses( 408 Operation *from, 409 function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) { 410 // If from is not a symbol table, check for uses. A symbol table defines a new 411 // scope, so we can't walk the attributes from the symbol table op. 412 if (!from->hasTrait<OpTrait::SymbolTable>()) { 413 if (walkSymbolRefs(from, callback).wasInterrupted()) 414 return WalkResult::interrupt(); 415 } 416 417 SmallVector<Region *, 1> worklist; 418 worklist.reserve(from->getNumRegions()); 419 for (Region ®ion : from->getRegions()) 420 worklist.push_back(®ion); 421 422 while (!worklist.empty()) { 423 Region *region = worklist.pop_back_val(); 424 for (Block &block : *region) { 425 for (Operation &op : block) { 426 if (walkSymbolRefs(&op, callback).wasInterrupted()) 427 return WalkResult::interrupt(); 428 429 // If this operation has regions, and it as well as its dialect aren't 430 // registered then conservatively fail. The operation may define a 431 // symbol table, so we can't opaquely know if we should traverse to find 432 // nested uses. 433 if (isPotentiallyUnknownSymbolTable(&op)) 434 return llvm::None; 435 436 // If this op defines a new symbol table scope, we can't traverse. Any 437 // symbol references nested within 'op' are different semantically. 438 if (!op.hasTrait<OpTrait::SymbolTable>()) { 439 for (Region ®ion : op.getRegions()) 440 worklist.push_back(®ion); 441 } 442 } 443 } 444 } 445 return WalkResult::advance(); 446 } 447 448 /// Walks all of the symbol scopes from 'symbol' to (inclusive) 'limit' invoking 449 /// the provided callback at each one with a properly scoped reference to 450 /// 'symbol'. The callback takes as parameters the symbol reference at the 451 /// current scope as well as the top-level operation representing the top of 452 /// that scope. 453 static Optional<WalkResult> walkSymbolScopes( 454 Operation *symbol, Operation *limit, 455 function_ref<Optional<WalkResult>(SymbolRefAttr, Operation *)> callback) { 456 StringRef symbolName = SymbolTable::getSymbolName(symbol); 457 assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit); 458 459 // Compute the ancestors of 'limit'. 460 llvm::SetVector<Operation *, SmallVector<Operation *, 4>, 461 SmallPtrSet<Operation *, 4>> 462 limitAncestors; 463 Operation *limitAncestor = limit; 464 do { 465 // Check to see if 'symbol' is an ancestor of 'limit'. 466 if (limitAncestor == symbol) { 467 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr 468 // doesn't support parent references. 469 if (getNearestSymbolTable(limit) != symbol->getParentOp()) 470 return WalkResult::advance(); 471 return callback(SymbolRefAttr::get(symbolName, symbol->getContext()), 472 limit); 473 } 474 475 limitAncestors.insert(limitAncestor); 476 } while ((limitAncestor = limitAncestor->getParentOp())); 477 478 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'. 479 Operation *commonAncestor = symbol->getParentOp(); 480 do { 481 if (limitAncestors.count(commonAncestor)) 482 break; 483 } while ((commonAncestor = commonAncestor->getParentOp())); 484 assert(commonAncestor && "'limit' and 'symbol' have no common ancestor"); 485 486 // Compute the set of valid nested references for 'symbol' as far up to the 487 // common ancestor as possible. 488 SmallVector<SymbolRefAttr, 2> references; 489 bool collectedAllReferences = succeeded(collectValidReferencesFor( 490 symbol, symbolName, commonAncestor, references)); 491 492 // Handle the case where the common ancestor is 'limit'. 493 if (commonAncestor == limit) { 494 // Walk each of the ancestors of 'symbol', calling the compute function for 495 // each one. 496 Operation *limitIt = symbol->getParentOp(); 497 for (size_t i = 0, e = references.size(); i != e; 498 ++i, limitIt = limitIt->getParentOp()) { 499 Optional<WalkResult> callbackResult = callback(references[i], limitIt); 500 if (callbackResult != WalkResult::advance()) 501 return callbackResult; 502 } 503 return WalkResult::advance(); 504 } 505 506 // Otherwise, we just need the symbol reference for 'symbol' that will be 507 // used within 'limit'. This is the last reference in the list we computed 508 // above if we were able to collect all references. 509 if (!collectedAllReferences) 510 return WalkResult::advance(); 511 return callback(references.back(), limit); 512 } 513 514 /// Walk the symbol scopes defined by 'limit' invoking the provided callback. 515 static Optional<WalkResult> walkSymbolScopes( 516 StringRef symbol, Operation *limit, 517 function_ref<Optional<WalkResult>(SymbolRefAttr, Operation *)> callback) { 518 return callback(SymbolRefAttr::get(symbol, limit->getContext()), limit); 519 } 520 521 /// Returns true if the given reference 'SubRef' is a sub reference of the 522 /// reference 'ref', i.e. 'ref' is a further qualified reference. 523 static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) { 524 if (ref == subRef) 525 return true; 526 527 // If the references are not pointer equal, check to see if `subRef` is a 528 // prefix of `ref`. 529 if (ref.isa<FlatSymbolRefAttr>() || 530 ref.getRootReference() != subRef.getRootReference()) 531 return false; 532 533 auto refLeafs = ref.getNestedReferences(); 534 auto subRefLeafs = subRef.getNestedReferences(); 535 return subRefLeafs.size() < refLeafs.size() && 536 subRefLeafs == refLeafs.take_front(subRefLeafs.size()); 537 } 538 539 //===----------------------------------------------------------------------===// 540 // SymbolTable::getSymbolUses 541 542 /// Get an iterator range for all of the uses, for any symbol, that are nested 543 /// within the given operation 'from'. This does not traverse into any nested 544 /// symbol tables, and will also only return uses on 'from' if it does not 545 /// also define a symbol table. This is because we treat the region as the 546 /// boundary of the symbol table, and not the op itself. This function returns 547 /// None if there are any unknown operations that may potentially be symbol 548 /// tables. 549 auto SymbolTable::getSymbolUses(Operation *from) -> Optional<UseRange> { 550 std::vector<SymbolUse> uses; 551 auto walkFn = [&](SymbolUse symbolUse, ArrayRef<int>) { 552 uses.push_back(symbolUse); 553 return WalkResult::advance(); 554 }; 555 auto result = walkSymbolUses(from, walkFn); 556 return result ? Optional<UseRange>(std::move(uses)) : Optional<UseRange>(); 557 } 558 559 //===----------------------------------------------------------------------===// 560 // SymbolTable::getSymbolUses 561 562 /// The implementation of SymbolTable::getSymbolUses below. 563 template <typename SymbolT> 564 static Optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol, 565 Operation *limit) { 566 std::vector<SymbolTable::SymbolUse> uses; 567 auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) { 568 return walkSymbolUses( 569 from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) { 570 if (isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef())) 571 uses.push_back(symbolUse); 572 return WalkResult::advance(); 573 }); 574 }; 575 if (walkSymbolScopes(symbol, limit, walkFn)) 576 return SymbolTable::UseRange(std::move(uses)); 577 return llvm::None; 578 } 579 580 /// Get all of the uses of the given symbol that are nested within the given 581 /// operation 'from', invoking the provided callback for each. This does not 582 /// traverse into any nested symbol tables, and will also only return uses on 583 /// 'from' if it does not also define a symbol table. This is because we treat 584 /// the region as the boundary of the symbol table, and not the op itself. This 585 /// function returns None if there are any unknown operations that may 586 /// potentially be symbol tables. 587 auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from) 588 -> Optional<UseRange> { 589 return getSymbolUsesImpl(symbol, from); 590 } 591 auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from) 592 -> Optional<UseRange> { 593 return getSymbolUsesImpl(symbol, from); 594 } 595 596 //===----------------------------------------------------------------------===// 597 // SymbolTable::symbolKnownUseEmpty 598 599 /// The implementation of SymbolTable::symbolKnownUseEmpty below. 600 template <typename SymbolT> 601 static bool symbolKnownUseEmptyImpl(SymbolT symbol, Operation *limit) { 602 // Walk all of the symbol uses looking for a reference to 'symbol'. 603 auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) { 604 return walkSymbolUses( 605 from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) { 606 return isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef()) 607 ? WalkResult::interrupt() 608 : WalkResult::advance(); 609 }); 610 }; 611 return walkSymbolScopes(symbol, limit, walkFn) == WalkResult::advance(); 612 } 613 614 /// Return if the given symbol is known to have no uses that are nested within 615 /// the given operation 'from'. This does not traverse into any nested symbol 616 /// tables, and will also only count uses on 'from' if it does not also define 617 /// a symbol table. This is because we treat the region as the boundary of the 618 /// symbol table, and not the op itself. This function will also return false if 619 /// there are any unknown operations that may potentially be symbol tables. 620 bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) { 621 return symbolKnownUseEmptyImpl(symbol, from); 622 } 623 bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) { 624 return symbolKnownUseEmptyImpl(symbol, from); 625 } 626 627 //===----------------------------------------------------------------------===// 628 // SymbolTable::replaceAllSymbolUses 629 630 /// Rebuild the given attribute container after replacing all references to a 631 /// symbol with the updated attribute in 'accesses'. 632 static Attribute rebuildAttrAfterRAUW( 633 Attribute container, 634 ArrayRef<std::pair<SmallVector<int, 1>, SymbolRefAttr>> accesses, 635 unsigned depth) { 636 // Given a range of Attributes, update the ones referred to by the given 637 // access chains to point to the new symbol attribute. 638 auto updateAttrs = [&](auto &&attrRange) { 639 auto attrBegin = std::begin(attrRange); 640 for (unsigned i = 0, e = accesses.size(); i != e;) { 641 ArrayRef<int> access = accesses[i].first; 642 Attribute &attr = *std::next(attrBegin, access[depth]); 643 644 // Check to see if this is a leaf access, i.e. a SymbolRef. 645 if (access.size() == depth + 1) { 646 attr = accesses[i].second; 647 ++i; 648 continue; 649 } 650 651 // Otherwise, this is a container. Collect all of the accesses for this 652 // index and recurse. The recursion here is bounded by the size of the 653 // largest access array. 654 auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) { 655 ArrayRef<int> nextAccess = it.first; 656 return nextAccess.size() > depth + 1 && 657 nextAccess[depth] == access[depth]; 658 }); 659 attr = rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1); 660 661 // Skip over all of the accesses that refer to the nested container. 662 i += nestedAccesses.size(); 663 } 664 }; 665 666 if (auto dictAttr = container.dyn_cast<DictionaryAttr>()) { 667 auto newAttrs = llvm::to_vector<4>(dictAttr.getValue()); 668 updateAttrs(make_second_range(newAttrs)); 669 return DictionaryAttr::get(newAttrs, dictAttr.getContext()); 670 } 671 auto newAttrs = llvm::to_vector<4>(container.cast<ArrayAttr>().getValue()); 672 updateAttrs(newAttrs); 673 return ArrayAttr::get(newAttrs, container.getContext()); 674 } 675 676 /// Generates a new symbol reference attribute with a new leaf reference. 677 static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, 678 FlatSymbolRefAttr newLeafAttr) { 679 if (oldAttr.isa<FlatSymbolRefAttr>()) 680 return newLeafAttr; 681 auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences()); 682 nestedRefs.back() = newLeafAttr; 683 return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs, 684 oldAttr.getContext()); 685 } 686 687 /// The implementation of SymbolTable::replaceAllSymbolUses below. 688 template <typename SymbolT> 689 static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, 690 StringRef newSymbol, 691 Operation *limit) { 692 // A collection of operations along with their new attribute dictionary. 693 std::vector<std::pair<Operation *, DictionaryAttr>> updatedAttrDicts; 694 695 // The current operation being processed. 696 Operation *curOp = nullptr; 697 698 // The set of access chains into the attribute dictionary of the current 699 // operation, as well as the replacement attribute to use. 700 SmallVector<std::pair<SmallVector<int, 1>, SymbolRefAttr>, 1> accessChains; 701 702 // Generate a new attribute dictionary for the current operation by replacing 703 // references to the old symbol. 704 auto generateNewAttrDict = [&] { 705 auto oldDict = curOp->getAttrList().getDictionary(); 706 auto newDict = rebuildAttrAfterRAUW(oldDict, accessChains, /*depth=*/0); 707 return newDict.cast<DictionaryAttr>(); 708 }; 709 710 // Generate a new attribute to replace the given attribute. 711 MLIRContext *ctx = limit->getContext(); 712 FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx); 713 auto scopeWalkFn = [&](SymbolRefAttr oldAttr, 714 Operation *from) -> Optional<WalkResult> { 715 SymbolRefAttr newAttr = generateNewRefAttr(oldAttr, newLeafAttr); 716 auto walkFn = [&](SymbolTable::SymbolUse symbolUse, 717 ArrayRef<int> accessChain) { 718 SymbolRefAttr useRef = symbolUse.getSymbolRef(); 719 if (!isReferencePrefixOf(oldAttr, useRef)) 720 return WalkResult::advance(); 721 722 // If we have a valid match, check to see if this is a proper 723 // subreference. If it is, then we will need to generate a different new 724 // attribute specifically for this use. 725 SymbolRefAttr replacementRef = newAttr; 726 if (useRef != oldAttr) { 727 if (oldAttr.isa<FlatSymbolRefAttr>()) { 728 replacementRef = 729 SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx); 730 } else { 731 auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences()); 732 nestedRefs[oldAttr.getNestedReferences().size() - 1] = newLeafAttr; 733 replacementRef = 734 SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx); 735 } 736 } 737 738 // If there was a previous operation, generate a new attribute dict 739 // for it. This means that we've finished processing the current 740 // operation, so generate a new dictionary for it. 741 if (curOp && symbolUse.getUser() != curOp) { 742 updatedAttrDicts.push_back({curOp, generateNewAttrDict()}); 743 accessChains.clear(); 744 } 745 746 // Record this access. 747 curOp = symbolUse.getUser(); 748 accessChains.push_back({llvm::to_vector<1>(accessChain), replacementRef}); 749 return WalkResult::advance(); 750 }; 751 if (!walkSymbolUses(from, walkFn)) 752 return llvm::None; 753 754 // Check to see if we have a dangling op that needs to be processed. 755 if (curOp) { 756 updatedAttrDicts.push_back({curOp, generateNewAttrDict()}); 757 curOp = nullptr; 758 } 759 return WalkResult::advance(); 760 }; 761 if (!walkSymbolScopes(symbol, limit, scopeWalkFn)) 762 return failure(); 763 764 // Update the attribute dictionaries as necessary. 765 for (auto &it : updatedAttrDicts) 766 it.first->setAttrs(it.second); 767 return success(); 768 } 769 770 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the 771 /// provided symbol 'newSymbol' that are nested within the given operation 772 /// 'from'. This does not traverse into any nested symbol tables, and will 773 /// also only replace uses on 'from' if it does not also define a symbol 774 /// table. This is because we treat the region as the boundary of the symbol 775 /// table, and not the op itself. If there are any unknown operations that may 776 /// potentially be symbol tables, no uses are replaced and failure is returned. 777 LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol, 778 StringRef newSymbol, 779 Operation *from) { 780 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); 781 } 782 LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, 783 StringRef newSymbol, 784 Operation *from) { 785 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); 786 } 787