1 //===- OperationSupport.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains out-of-line implementations of the support types that 10 // Operation and related classes build on top of. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/OperationSupport.h" 15 #include "mlir/IR/BuiltinAttributes.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/OpDefinition.h" 18 #include "llvm/ADT/BitVector.h" 19 #include <numeric> 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // NamedAttrList 25 //===----------------------------------------------------------------------===// 26 27 NamedAttrList::NamedAttrList(ArrayRef<NamedAttribute> attributes) { 28 assign(attributes.begin(), attributes.end()); 29 } 30 31 NamedAttrList::NamedAttrList(DictionaryAttr attributes) 32 : NamedAttrList(attributes ? attributes.getValue() 33 : ArrayRef<NamedAttribute>()) { 34 dictionarySorted.setPointerAndInt(attributes, true); 35 } 36 37 NamedAttrList::NamedAttrList(const_iterator in_start, const_iterator in_end) { 38 assign(in_start, in_end); 39 } 40 41 ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; } 42 43 Optional<NamedAttribute> NamedAttrList::findDuplicate() const { 44 Optional<NamedAttribute> duplicate = 45 DictionaryAttr::findDuplicate(attrs, isSorted()); 46 // DictionaryAttr::findDuplicate will sort the list, so reset the sorted 47 // state. 
48 if (!isSorted()) 49 dictionarySorted.setPointerAndInt(nullptr, true); 50 return duplicate; 51 } 52 53 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const { 54 if (!isSorted()) { 55 DictionaryAttr::sortInPlace(attrs); 56 dictionarySorted.setPointerAndInt(nullptr, true); 57 } 58 if (!dictionarySorted.getPointer()) 59 dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs)); 60 return dictionarySorted.getPointer().cast<DictionaryAttr>(); 61 } 62 63 /// Add an attribute with the specified name. 64 void NamedAttrList::append(StringRef name, Attribute attr) { 65 append(Identifier::get(name, attr.getContext()), attr); 66 } 67 68 /// Replaces the attributes with new list of attributes. 69 void NamedAttrList::assign(const_iterator in_start, const_iterator in_end) { 70 DictionaryAttr::sort(ArrayRef<NamedAttribute>{in_start, in_end}, attrs); 71 dictionarySorted.setPointerAndInt(nullptr, true); 72 } 73 74 void NamedAttrList::push_back(NamedAttribute newAttribute) { 75 assert(newAttribute.second && "unexpected null attribute"); 76 if (isSorted()) 77 dictionarySorted.setInt( 78 attrs.empty() || 79 strcmp(attrs.back().first.data(), newAttribute.first.data()) < 0); 80 dictionarySorted.setPointer(nullptr); 81 attrs.push_back(newAttribute); 82 } 83 84 /// Return the specified attribute if present, null otherwise. 85 Attribute NamedAttrList::get(StringRef name) const { 86 auto it = findAttr(*this, name); 87 return it.second ? it.first->second : Attribute(); 88 } 89 Attribute NamedAttrList::get(Identifier name) const { 90 auto it = findAttr(*this, name); 91 return it.second ? it.first->second : Attribute(); 92 } 93 94 /// Return the specified named attribute if present, None otherwise. 95 Optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const { 96 auto it = findAttr(*this, name); 97 return it.second ? 
*it.first : Optional<NamedAttribute>(); 98 } 99 Optional<NamedAttribute> NamedAttrList::getNamed(Identifier name) const { 100 auto it = findAttr(*this, name); 101 return it.second ? *it.first : Optional<NamedAttribute>(); 102 } 103 104 /// If the an attribute exists with the specified name, change it to the new 105 /// value. Otherwise, add a new attribute with the specified name/value. 106 Attribute NamedAttrList::set(Identifier name, Attribute value) { 107 assert(value && "attributes may never be null"); 108 109 // Look for an existing attribute with the given name, and set its value 110 // in-place. Return the previous value of the attribute, if there was one. 111 auto it = findAttr(*this, name); 112 if (it.second) { 113 // Update the existing attribute by swapping out the old value for the new 114 // value. Return the old value. 115 if (it.first->second != value) { 116 std::swap(it.first->second, value); 117 // If the attributes have changed, the dictionary is invalidated. 118 dictionarySorted.setPointer(nullptr); 119 } 120 return value; 121 } 122 // Perform a string lookup to insert the new attribute into its sorted 123 // position. 124 if (isSorted()) 125 it = findAttr(*this, name.strref()); 126 attrs.insert(it.first, {name, value}); 127 // Invalidate the dictionary. Return null as there was no previous value. 128 dictionarySorted.setPointer(nullptr); 129 return Attribute(); 130 } 131 132 Attribute NamedAttrList::set(StringRef name, Attribute value) { 133 assert(value && "attributes may never be null"); 134 return set(mlir::Identifier::get(name, value.getContext()), value); 135 } 136 137 Attribute 138 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) { 139 // Erasing does not affect the sorted property. 140 Attribute attr = it->second; 141 attrs.erase(it); 142 dictionarySorted.setPointer(nullptr); 143 return attr; 144 } 145 146 Attribute NamedAttrList::erase(Identifier name) { 147 auto it = findAttr(*this, name); 148 return it.second ? 
eraseImpl(it.first) : Attribute(); 149 } 150 151 Attribute NamedAttrList::erase(StringRef name) { 152 auto it = findAttr(*this, name); 153 return it.second ? eraseImpl(it.first) : Attribute(); 154 } 155 156 NamedAttrList & 157 NamedAttrList::operator=(const SmallVectorImpl<NamedAttribute> &rhs) { 158 assign(rhs.begin(), rhs.end()); 159 return *this; 160 } 161 162 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; } 163 164 //===----------------------------------------------------------------------===// 165 // OperationState 166 //===----------------------------------------------------------------------===// 167 168 OperationState::OperationState(Location location, StringRef name) 169 : location(location), name(name, location->getContext()) {} 170 171 OperationState::OperationState(Location location, OperationName name) 172 : location(location), name(name) {} 173 174 OperationState::OperationState(Location location, StringRef name, 175 ValueRange operands, TypeRange types, 176 ArrayRef<NamedAttribute> attributes, 177 BlockRange successors, 178 MutableArrayRef<std::unique_ptr<Region>> regions) 179 : location(location), name(name, location->getContext()), 180 operands(operands.begin(), operands.end()), 181 types(types.begin(), types.end()), 182 attributes(attributes.begin(), attributes.end()), 183 successors(successors.begin(), successors.end()) { 184 for (std::unique_ptr<Region> &r : regions) 185 this->regions.push_back(std::move(r)); 186 } 187 188 void OperationState::addOperands(ValueRange newOperands) { 189 operands.append(newOperands.begin(), newOperands.end()); 190 } 191 192 void OperationState::addSuccessors(BlockRange newSuccessors) { 193 successors.append(newSuccessors.begin(), newSuccessors.end()); 194 } 195 196 Region *OperationState::addRegion() { 197 regions.emplace_back(new Region); 198 return regions.back().get(); 199 } 200 201 void OperationState::addRegion(std::unique_ptr<Region> &®ion) { 202 regions.push_back(std::move(region)); 
203 } 204 205 void OperationState::addRegions( 206 MutableArrayRef<std::unique_ptr<Region>> regions) { 207 for (std::unique_ptr<Region> ®ion : regions) 208 addRegion(std::move(region)); 209 } 210 211 //===----------------------------------------------------------------------===// 212 // OperandStorage 213 //===----------------------------------------------------------------------===// 214 215 detail::OperandStorage::OperandStorage(Operation *owner, 216 OpOperand *trailingOperands, 217 ValueRange values) 218 : isStorageDynamic(false), operandStorage(trailingOperands) { 219 numOperands = capacity = values.size(); 220 for (unsigned i = 0; i < numOperands; ++i) 221 new (&operandStorage[i]) OpOperand(owner, values[i]); 222 } 223 224 detail::OperandStorage::~OperandStorage() { 225 for (auto &operand : getOperands()) 226 operand.~OpOperand(); 227 228 // If the storage is dynamic, deallocate it. 229 if (isStorageDynamic) 230 free(operandStorage); 231 } 232 233 /// Replace the operands contained in the storage with the ones provided in 234 /// 'values'. 235 void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) { 236 MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size()); 237 for (unsigned i = 0, e = values.size(); i != e; ++i) 238 storageOperands[i].set(values[i]); 239 } 240 241 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 242 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 243 /// than the range pointed to by 'start'+'length'. 244 void detail::OperandStorage::setOperands(Operation *owner, unsigned start, 245 unsigned length, ValueRange operands) { 246 // If the new size is the same, we can update inplace. 
247 unsigned newSize = operands.size(); 248 if (newSize == length) { 249 MutableArrayRef<OpOperand> storageOperands = getOperands(); 250 for (unsigned i = 0, e = length; i != e; ++i) 251 storageOperands[start + i].set(operands[i]); 252 return; 253 } 254 // If the new size is greater, remove the extra operands and set the rest 255 // inplace. 256 if (newSize < length) { 257 eraseOperands(start + operands.size(), length - newSize); 258 setOperands(owner, start, newSize, operands); 259 return; 260 } 261 // Otherwise, the new size is greater so we need to grow the storage. 262 auto storageOperands = resize(owner, size() + (newSize - length)); 263 264 // Shift operands to the right to make space for the new operands. 265 unsigned rotateSize = storageOperands.size() - (start + length); 266 auto rbegin = storageOperands.rbegin(); 267 std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize); 268 269 // Update the operands inplace. 270 for (unsigned i = 0, e = operands.size(); i != e; ++i) 271 storageOperands[start + i].set(operands[i]); 272 } 273 274 /// Erase an operand held by the storage. 275 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) { 276 MutableArrayRef<OpOperand> operands = getOperands(); 277 assert((start + length) <= operands.size()); 278 numOperands -= length; 279 280 // Shift all operands down if the operand to remove is not at the end. 281 if (start != numOperands) { 282 auto *indexIt = std::next(operands.begin(), start); 283 std::rotate(indexIt, std::next(indexIt, length), operands.end()); 284 } 285 for (unsigned i = 0; i != length; ++i) 286 operands[numOperands + i].~OpOperand(); 287 } 288 289 void detail::OperandStorage::eraseOperands( 290 const llvm::BitVector &eraseIndices) { 291 MutableArrayRef<OpOperand> operands = getOperands(); 292 assert(eraseIndices.size() == operands.size()); 293 294 // Check that at least one operand is erased. 
295 int firstErasedIndice = eraseIndices.find_first(); 296 if (firstErasedIndice == -1) 297 return; 298 299 // Shift all of the removed operands to the end, and destroy them. 300 numOperands = firstErasedIndice; 301 for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i) 302 if (!eraseIndices.test(i)) 303 operands[numOperands++] = std::move(operands[i]); 304 for (OpOperand &operand : operands.drop_front(numOperands)) 305 operand.~OpOperand(); 306 } 307 308 /// Resize the storage to the given size. Returns the array containing the new 309 /// operands. 310 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner, 311 unsigned newSize) { 312 // If the number of operands is less than or equal to the current amount, we 313 // can just update in place. 314 MutableArrayRef<OpOperand> origOperands = getOperands(); 315 if (newSize <= numOperands) { 316 // If the number of new size is less than the current, remove any extra 317 // operands. 318 for (unsigned i = newSize; i != numOperands; ++i) 319 origOperands[i].~OpOperand(); 320 numOperands = newSize; 321 return origOperands.take_front(newSize); 322 } 323 324 // If the new size is within the original inline capacity, grow inplace. 325 if (newSize <= capacity) { 326 OpOperand *opBegin = origOperands.data(); 327 for (unsigned e = newSize; numOperands != e; ++numOperands) 328 new (&opBegin[numOperands]) OpOperand(owner); 329 return MutableArrayRef<OpOperand>(opBegin, newSize); 330 } 331 332 // Otherwise, we need to allocate a new storage. 333 unsigned newCapacity = 334 std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize); 335 OpOperand *newOperandStorage = 336 reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity)); 337 338 // Move the current operands to the new storage. 
339 MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize); 340 std::uninitialized_copy(std::make_move_iterator(origOperands.begin()), 341 std::make_move_iterator(origOperands.end()), 342 newOperands.begin()); 343 344 // Destroy the original operands. 345 for (auto &operand : origOperands) 346 operand.~OpOperand(); 347 348 // Initialize any new operands. 349 for (unsigned e = newSize; numOperands != e; ++numOperands) 350 new (&newOperands[numOperands]) OpOperand(owner); 351 352 // If the current storage is dynamic, free it. 353 if (isStorageDynamic) 354 free(operandStorage); 355 356 // Update the storage representation to use the new dynamic storage. 357 operandStorage = newOperandStorage; 358 capacity = newCapacity; 359 isStorageDynamic = true; 360 return newOperands; 361 } 362 363 //===----------------------------------------------------------------------===// 364 // Operation Value-Iterators 365 //===----------------------------------------------------------------------===// 366 367 //===----------------------------------------------------------------------===// 368 // OperandRange 369 370 unsigned OperandRange::getBeginOperandIndex() const { 371 assert(!empty() && "range must not be empty"); 372 return base->getOperandNumber(); 373 } 374 375 OperandRangeRange OperandRange::split(ElementsAttr segmentSizes) const { 376 return OperandRangeRange(*this, segmentSizes); 377 } 378 379 //===----------------------------------------------------------------------===// 380 // OperandRangeRange 381 382 OperandRangeRange::OperandRangeRange(OperandRange operands, 383 Attribute operandSegments) 384 : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0, 385 operandSegments.cast<DenseElementsAttr>().size()) {} 386 387 OperandRange OperandRangeRange::join() const { 388 const OwnerT &owner = getBase(); 389 auto sizeData = owner.second.cast<DenseElementsAttr>().getValues<uint32_t>(); 390 return OperandRange(owner.first, 391 std::accumulate(sizeData.begin(), 
sizeData.end(), 0)); 392 } 393 394 OperandRange OperandRangeRange::dereference(const OwnerT &object, 395 ptrdiff_t index) { 396 auto sizeData = object.second.cast<DenseElementsAttr>().getValues<uint32_t>(); 397 uint32_t startIndex = 398 std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); 399 return OperandRange(object.first + startIndex, *(sizeData.begin() + index)); 400 } 401 402 //===----------------------------------------------------------------------===// 403 // MutableOperandRange 404 405 /// Construct a new mutable range from the given operand, operand start index, 406 /// and range length. 407 MutableOperandRange::MutableOperandRange( 408 Operation *owner, unsigned start, unsigned length, 409 ArrayRef<OperandSegment> operandSegments) 410 : owner(owner), start(start), length(length), 411 operandSegments(operandSegments.begin(), operandSegments.end()) { 412 assert((start + length) <= owner->getNumOperands() && "invalid range"); 413 } 414 MutableOperandRange::MutableOperandRange(Operation *owner) 415 : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {} 416 417 /// Slice this range into a sub range, with the additional operand segment. 418 MutableOperandRange 419 MutableOperandRange::slice(unsigned subStart, unsigned subLen, 420 Optional<OperandSegment> segment) const { 421 assert((subStart + subLen) <= length && "invalid sub-range"); 422 MutableOperandRange subSlice(owner, start + subStart, subLen, 423 operandSegments); 424 if (segment) 425 subSlice.operandSegments.push_back(*segment); 426 return subSlice; 427 } 428 429 /// Append the given values to the range. 430 void MutableOperandRange::append(ValueRange values) { 431 if (values.empty()) 432 return; 433 owner->insertOperands(start + length, values); 434 updateLength(length + values.size()); 435 } 436 437 /// Assign this range to the given values. 
438 void MutableOperandRange::assign(ValueRange values) { 439 owner->setOperands(start, length, values); 440 if (length != values.size()) 441 updateLength(/*newLength=*/values.size()); 442 } 443 444 /// Assign the range to the given value. 445 void MutableOperandRange::assign(Value value) { 446 if (length == 1) { 447 owner->setOperand(start, value); 448 } else { 449 owner->setOperands(start, length, value); 450 updateLength(/*newLength=*/1); 451 } 452 } 453 454 /// Erase the operands within the given sub-range. 455 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) { 456 assert((subStart + subLen) <= length && "invalid sub-range"); 457 if (length == 0) 458 return; 459 owner->eraseOperands(start + subStart, subLen); 460 updateLength(length - subLen); 461 } 462 463 /// Clear this range and erase all of the operands. 464 void MutableOperandRange::clear() { 465 if (length != 0) { 466 owner->eraseOperands(start, length); 467 updateLength(/*newLength=*/0); 468 } 469 } 470 471 /// Allow implicit conversion to an OperandRange. 472 MutableOperandRange::operator OperandRange() const { 473 return owner->getOperands().slice(start, length); 474 } 475 476 MutableOperandRangeRange 477 MutableOperandRange::split(NamedAttribute segmentSizes) const { 478 return MutableOperandRangeRange(*this, segmentSizes); 479 } 480 481 /// Update the length of this range to the one provided. 482 void MutableOperandRange::updateLength(unsigned newLength) { 483 int32_t diff = int32_t(newLength) - int32_t(length); 484 length = newLength; 485 486 // Update any of the provided segment attributes. 
487 for (OperandSegment &segment : operandSegments) { 488 auto attr = segment.second.second.cast<DenseIntElementsAttr>(); 489 SmallVector<int32_t, 8> segments(attr.getValues<int32_t>()); 490 segments[segment.first] += diff; 491 segment.second.second = DenseIntElementsAttr::get(attr.getType(), segments); 492 owner->setAttr(segment.second.first, segment.second.second); 493 } 494 } 495 496 //===----------------------------------------------------------------------===// 497 // MutableOperandRangeRange 498 499 MutableOperandRangeRange::MutableOperandRangeRange( 500 const MutableOperandRange &operands, NamedAttribute operandSegmentAttr) 501 : MutableOperandRangeRange( 502 OwnerT(operands, operandSegmentAttr), 0, 503 operandSegmentAttr.second.cast<DenseElementsAttr>().size()) {} 504 505 MutableOperandRange MutableOperandRangeRange::join() const { 506 return getBase().first; 507 } 508 509 MutableOperandRangeRange::operator OperandRangeRange() const { 510 return OperandRangeRange(getBase().first, 511 getBase().second.second.cast<DenseElementsAttr>()); 512 } 513 514 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object, 515 ptrdiff_t index) { 516 auto sizeData = 517 object.second.second.cast<DenseElementsAttr>().getValues<uint32_t>(); 518 uint32_t startIndex = 519 std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); 520 return object.first.slice( 521 startIndex, *(sizeData.begin() + index), 522 MutableOperandRange::OperandSegment(index, object.second)); 523 } 524 525 //===----------------------------------------------------------------------===// 526 // ResultRange 527 528 ResultRange::ResultRange(OpResult result) 529 : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()), 530 1) {} 531 532 ResultRange::use_range ResultRange::getUses() const { 533 return {use_begin(), use_end()}; 534 } 535 ResultRange::use_iterator ResultRange::use_begin() const { 536 return use_iterator(*this); 537 } 538 ResultRange::use_iterator 
ResultRange::use_end() const { 539 return use_iterator(*this, /*end=*/true); 540 } 541 ResultRange::user_range ResultRange::getUsers() { 542 return {user_begin(), user_end()}; 543 } 544 ResultRange::user_iterator ResultRange::user_begin() { 545 return user_iterator(use_begin()); 546 } 547 ResultRange::user_iterator ResultRange::user_end() { 548 return user_iterator(use_end()); 549 } 550 551 ResultRange::UseIterator::UseIterator(ResultRange results, bool end) 552 : it(end ? results.end() : results.begin()), endIt(results.end()) { 553 // Only initialize current use if there are results/can be uses. 554 if (it != endIt) 555 skipOverResultsWithNoUsers(); 556 } 557 558 ResultRange::UseIterator &ResultRange::UseIterator::operator++() { 559 // We increment over uses, if we reach the last use then move to next 560 // result. 561 if (use != (*it).use_end()) 562 ++use; 563 if (use == (*it).use_end()) { 564 ++it; 565 skipOverResultsWithNoUsers(); 566 } 567 return *this; 568 } 569 570 void ResultRange::UseIterator::skipOverResultsWithNoUsers() { 571 while (it != endIt && (*it).use_empty()) 572 ++it; 573 574 // If we are at the last result, then set use to first use of 575 // first result (sentinel value used for end). 576 if (it == endIt) 577 use = {}; 578 else 579 use = (*it).use_begin(); 580 } 581 582 void ResultRange::replaceAllUsesWith(Operation *op) { 583 replaceAllUsesWith(op->getResults()); 584 } 585 586 //===----------------------------------------------------------------------===// 587 // ValueRange 588 589 ValueRange::ValueRange(ArrayRef<Value> values) 590 : ValueRange(values.data(), values.size()) {} 591 ValueRange::ValueRange(OperandRange values) 592 : ValueRange(values.begin().getBase(), values.size()) {} 593 ValueRange::ValueRange(ResultRange values) 594 : ValueRange(values.getBase(), values.size()) {} 595 596 /// See `llvm::detail::indexed_accessor_range_base` for details. 
ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
                                           ptrdiff_t index) {
  // The owner is one of three storage kinds: a plain `Value` array, an
  // `OpOperand` array, or an `OpResultImpl` (results are reached through
  // getNextResultAtOffset rather than pointer arithmetic).
  if (const auto *value = owner.dyn_cast<const Value *>())
    return {value + index};
  if (auto *operand = owner.dyn_cast<OpOperand *>())
    return {operand + index};
  return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
  // Same three storage kinds as offset_base; operands yield their current
  // value via get().
  if (const auto *value = owner.dyn_cast<const Value *>())
    return value[index];
  if (auto *operand = owner.dyn_cast<OpOperand *>())
    return operand[index].get();
  return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
}

//===----------------------------------------------------------------------===//
// Operation Equivalency
//===----------------------------------------------------------------------===//

/// Compute a hash for `op` from its name, attribute dictionary, result types,
/// and the caller-provided hashes of each operand and each result.
/// NOTE(review): `flags` is currently not consulted when hashing.
llvm::hash_code OperationEquivalence::computeHash(
    Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
    function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
  // Hash operations based upon their:
  //   - Operation Name
  //   - Attributes
  //   - Result Types
  llvm::hash_code hash = llvm::hash_combine(
      op->getName(), op->getAttrDictionary(), op->getResultTypes());

  //   - Operands
  for (Value operand : op->getOperands())
    hash = llvm::hash_combine(hash, hashOperands(operand));
  //   - Results
  for (Value result : op->getResults())
    hash = llvm::hash_combine(hash, hashResults(result));
  return hash;
}

/// Return true if regions `lhs` and `rhs` are structurally equivalent. Blocks
/// are compared pairwise (presumably false on a block-count mismatch, per
/// llvm::all_of_zip — confirm). Each block pair must have matching argument
/// counts and types (and locations unless IgnoreLocations is set), and
/// pairwise-equivalent operations. Block correspondences, including those
/// implied by successor lists, are recorded in `blocksMap`, and a block may
/// only ever map to one counterpart. Block arguments are fed to `mapOperands`,
/// so the callbacks' side effects depend on this visitation order.
static bool
isRegionEquivalentTo(Region *lhs, Region *rhs,
                     function_ref<LogicalResult(Value, Value)> mapOperands,
                     function_ref<LogicalResult(Value, Value)> mapResults,
                     OperationEquivalence::Flags flags) {
  DenseMap<Block *, Block *> blocksMap;
  auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
    // Check block arguments.
    if (lBlock.getNumArguments() != rBlock.getNumArguments())
      return false;

    // Map the two blocks. If lBlock was already mapped (e.g. as a successor)
    // to a different block, the regions are not equivalent.
    auto insertion = blocksMap.insert({&lBlock, &rBlock});
    if (insertion.first->getSecond() != &rBlock)
      return false;

    for (auto argPair :
         llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
      Value curArg = std::get<0>(argPair);
      Value otherArg = std::get<1>(argPair);
      if (curArg.getType() != otherArg.getType())
        return false;
      if (!(flags & OperationEquivalence::IgnoreLocations) &&
          curArg.getLoc() != otherArg.getLoc())
        return false;
      // Check if this value was already mapped to another value.
      if (failed(mapOperands(curArg, otherArg)))
        return false;
    }

    auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
      // Check for op equality (recursively).
      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, mapOperands,
                                                mapResults, flags))
        return false;
      // Check successor mapping.
      for (auto successorsPair :
           llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
        Block *curSuccessor = std::get<0>(successorsPair);
        Block *otherSuccessor = std::get<1>(successorsPair);
        auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
        if (insertion.first->getSecond() != otherSuccessor)
          return false;
      }
      return true;
    };
    return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
  };
  return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
}

/// Return true if `lhs` and `rhs` are equivalent. They must have the same name,
/// attribute dictionary, and region, successor, operand and result counts; the
/// same location unless IgnoreLocations is set; matching operand and result
/// types, with each pair accepted by `mapOperands` / `mapResults`; and
/// equivalent regions (checked recursively). Identical pointers are trivially
/// equivalent.
bool OperationEquivalence::isEquivalentTo(
    Operation *lhs, Operation *rhs,
    function_ref<LogicalResult(Value, Value)> mapOperands,
    function_ref<LogicalResult(Value, Value)> mapResults, Flags flags) {
  if (lhs == rhs)
    return true;

  // Compare the operation properties.
  if (lhs->getName() != rhs->getName() ||
      lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
      lhs->getNumRegions() != rhs->getNumRegions() ||
      lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
      lhs->getNumOperands() != rhs->getNumOperands() ||
      lhs->getNumResults() != rhs->getNumResults())
    return false;
  if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
    return false;

  // Pairwise type check plus mapping callback; stops at the first mismatch.
  auto checkValueRangeMapping =
      [](ValueRange lhs, ValueRange rhs,
         function_ref<LogicalResult(Value, Value)> mapValues) {
        for (auto operandPair : llvm::zip(lhs, rhs)) {
          Value curArg = std::get<0>(operandPair);
          Value otherArg = std::get<1>(operandPair);
          if (curArg.getType() != otherArg.getType())
            return false;
          if (failed(mapValues(curArg, otherArg)))
            return false;
        }
        return true;
      };
  // Check mapping of operands and results.
  if (!checkValueRangeMapping(lhs->getOperands(), rhs->getOperands(),
                              mapOperands))
    return false;
  if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
    return false;
  for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
    if (!isRegionEquivalentTo(&std::get<0>(regionPair),
                              &std::get<1>(regionPair), mapOperands, mapResults,
                              flags))
      return false;
  return true;
}