1 //===- OperationSupport.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains out-of-line implementations of the support types that 10 // Operation and related classes build on top of. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/OperationSupport.h" 15 #include "mlir/IR/BuiltinAttributes.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/OpDefinition.h" 18 #include "llvm/ADT/BitVector.h" 19 #include <numeric> 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // NamedAttrList 25 //===----------------------------------------------------------------------===// 26 27 NamedAttrList::NamedAttrList(ArrayRef<NamedAttribute> attributes) { 28 assign(attributes.begin(), attributes.end()); 29 } 30 31 NamedAttrList::NamedAttrList(DictionaryAttr attributes) 32 : NamedAttrList(attributes ? attributes.getValue() 33 : ArrayRef<NamedAttribute>()) { 34 dictionarySorted.setPointerAndInt(attributes, true); 35 } 36 37 NamedAttrList::NamedAttrList(const_iterator inStart, const_iterator inEnd) { 38 assign(inStart, inEnd); 39 } 40 41 ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; } 42 43 Optional<NamedAttribute> NamedAttrList::findDuplicate() const { 44 Optional<NamedAttribute> duplicate = 45 DictionaryAttr::findDuplicate(attrs, isSorted()); 46 // DictionaryAttr::findDuplicate will sort the list, so reset the sorted 47 // state. 48 if (!isSorted()) 49 dictionarySorted.setPointerAndInt(nullptr, true); 50 return duplicate; 51 } 52 53 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const { 54 if (!isSorted()) { 55 DictionaryAttr::sortInPlace(attrs); 56 dictionarySorted.setPointerAndInt(nullptr, true); 57 } 58 if (!dictionarySorted.getPointer()) 59 dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs)); 60 return dictionarySorted.getPointer().cast<DictionaryAttr>(); 61 } 62 63 /// Add an attribute with the specified name. 64 void NamedAttrList::append(StringRef name, Attribute attr) { 65 append(StringAttr::get(attr.getContext(), name), attr); 66 } 67 68 /// Replaces the attributes with new list of attributes. 69 void NamedAttrList::assign(const_iterator inStart, const_iterator inEnd) { 70 DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs); 71 dictionarySorted.setPointerAndInt(nullptr, true); 72 } 73 74 void NamedAttrList::push_back(NamedAttribute newAttribute) { 75 if (isSorted()) 76 dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute); 77 dictionarySorted.setPointer(nullptr); 78 attrs.push_back(newAttribute); 79 } 80 81 /// Return the specified attribute if present, null otherwise. 82 Attribute NamedAttrList::get(StringRef name) const { 83 auto it = findAttr(*this, name); 84 return it.second ? it.first->getValue() : Attribute(); 85 } 86 Attribute NamedAttrList::get(StringAttr name) const { 87 auto it = findAttr(*this, name); 88 return it.second ? it.first->getValue() : Attribute(); 89 } 90 91 /// Return the specified named attribute if present, None otherwise. 92 Optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const { 93 auto it = findAttr(*this, name); 94 return it.second ? *it.first : Optional<NamedAttribute>(); 95 } 96 Optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const { 97 auto it = findAttr(*this, name); 98 return it.second ? *it.first : Optional<NamedAttribute>(); 99 } 100 101 /// If the an attribute exists with the specified name, change it to the new 102 /// value. Otherwise, add a new attribute with the specified name/value. 103 Attribute NamedAttrList::set(StringAttr name, Attribute value) { 104 assert(value && "attributes may never be null"); 105 106 // Look for an existing attribute with the given name, and set its value 107 // in-place. Return the previous value of the attribute, if there was one. 108 auto it = findAttr(*this, name); 109 if (it.second) { 110 // Update the existing attribute by swapping out the old value for the new 111 // value. Return the old value. 112 Attribute oldValue = it.first->getValue(); 113 if (it.first->getValue() != value) { 114 it.first->setValue(value); 115 116 // If the attributes have changed, the dictionary is invalidated. 117 dictionarySorted.setPointer(nullptr); 118 } 119 return oldValue; 120 } 121 // Perform a string lookup to insert the new attribute into its sorted 122 // position. 123 if (isSorted()) 124 it = findAttr(*this, name.strref()); 125 attrs.insert(it.first, {name, value}); 126 // Invalidate the dictionary. Return null as there was no previous value. 127 dictionarySorted.setPointer(nullptr); 128 return Attribute(); 129 } 130 131 Attribute NamedAttrList::set(StringRef name, Attribute value) { 132 assert(value && "attributes may never be null"); 133 return set(mlir::StringAttr::get(value.getContext(), name), value); 134 } 135 136 Attribute 137 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) { 138 // Erasing does not affect the sorted property. 139 Attribute attr = it->getValue(); 140 attrs.erase(it); 141 dictionarySorted.setPointer(nullptr); 142 return attr; 143 } 144 145 Attribute NamedAttrList::erase(StringAttr name) { 146 auto it = findAttr(*this, name); 147 return it.second ? eraseImpl(it.first) : Attribute(); 148 } 149 150 Attribute NamedAttrList::erase(StringRef name) { 151 auto it = findAttr(*this, name); 152 return it.second ? eraseImpl(it.first) : Attribute(); 153 } 154 155 NamedAttrList & 156 NamedAttrList::operator=(const SmallVectorImpl<NamedAttribute> &rhs) { 157 assign(rhs.begin(), rhs.end()); 158 return *this; 159 } 160 161 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; } 162 163 //===----------------------------------------------------------------------===// 164 // OperationState 165 //===----------------------------------------------------------------------===// 166 167 OperationState::OperationState(Location location, StringRef name) 168 : location(location), name(name, location->getContext()) {} 169 170 OperationState::OperationState(Location location, OperationName name) 171 : location(location), name(name) {} 172 173 OperationState::OperationState(Location location, OperationName name, 174 ValueRange operands, TypeRange types, 175 ArrayRef<NamedAttribute> attributes, 176 BlockRange successors, 177 MutableArrayRef<std::unique_ptr<Region>> regions) 178 : location(location), name(name), 179 operands(operands.begin(), operands.end()), 180 types(types.begin(), types.end()), 181 attributes(attributes.begin(), attributes.end()), 182 successors(successors.begin(), successors.end()) { 183 for (std::unique_ptr<Region> &r : regions) 184 this->regions.push_back(std::move(r)); 185 } 186 OperationState::OperationState(Location location, StringRef name, 187 ValueRange operands, TypeRange types, 188 ArrayRef<NamedAttribute> attributes, 189 BlockRange successors, 190 MutableArrayRef<std::unique_ptr<Region>> regions) 191 : OperationState(location, OperationName(name, location.getContext()), 192 operands, types, attributes, successors, regions) {} 193 194 void OperationState::addOperands(ValueRange newOperands) { 195 operands.append(newOperands.begin(), newOperands.end()); 196 } 197 198 void OperationState::addSuccessors(BlockRange newSuccessors) { 199 successors.append(newSuccessors.begin(), newSuccessors.end()); 200 } 201 202 Region *OperationState::addRegion() { 203 regions.emplace_back(new Region); 204 return regions.back().get(); 205 } 206 207 void OperationState::addRegion(std::unique_ptr<Region> &®ion) { 208 regions.push_back(std::move(region)); 209 } 210 211 void OperationState::addRegions( 212 MutableArrayRef<std::unique_ptr<Region>> regions) { 213 for (std::unique_ptr<Region> ®ion : regions) 214 addRegion(std::move(region)); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // OperandStorage 219 //===----------------------------------------------------------------------===// 220 221 detail::OperandStorage::OperandStorage(Operation *owner, 222 OpOperand *trailingOperands, 223 ValueRange values) 224 : isStorageDynamic(false), operandStorage(trailingOperands) { 225 numOperands = capacity = values.size(); 226 for (unsigned i = 0; i < numOperands; ++i) 227 new (&operandStorage[i]) OpOperand(owner, values[i]); 228 } 229 230 detail::OperandStorage::~OperandStorage() { 231 for (auto &operand : getOperands()) 232 operand.~OpOperand(); 233 234 // If the storage is dynamic, deallocate it. 235 if (isStorageDynamic) 236 free(operandStorage); 237 } 238 239 /// Replace the operands contained in the storage with the ones provided in 240 /// 'values'. 241 void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) { 242 MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size()); 243 for (unsigned i = 0, e = values.size(); i != e; ++i) 244 storageOperands[i].set(values[i]); 245 } 246 247 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 248 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 249 /// than the range pointed to by 'start'+'length'. 250 void detail::OperandStorage::setOperands(Operation *owner, unsigned start, 251 unsigned length, ValueRange operands) { 252 // If the new size is the same, we can update inplace. 253 unsigned newSize = operands.size(); 254 if (newSize == length) { 255 MutableArrayRef<OpOperand> storageOperands = getOperands(); 256 for (unsigned i = 0, e = length; i != e; ++i) 257 storageOperands[start + i].set(operands[i]); 258 return; 259 } 260 // If the new size is greater, remove the extra operands and set the rest 261 // inplace. 262 if (newSize < length) { 263 eraseOperands(start + operands.size(), length - newSize); 264 setOperands(owner, start, newSize, operands); 265 return; 266 } 267 // Otherwise, the new size is greater so we need to grow the storage. 268 auto storageOperands = resize(owner, size() + (newSize - length)); 269 270 // Shift operands to the right to make space for the new operands. 271 unsigned rotateSize = storageOperands.size() - (start + length); 272 auto rbegin = storageOperands.rbegin(); 273 std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize); 274 275 // Update the operands inplace. 276 for (unsigned i = 0, e = operands.size(); i != e; ++i) 277 storageOperands[start + i].set(operands[i]); 278 } 279 280 /// Erase an operand held by the storage. 281 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) { 282 MutableArrayRef<OpOperand> operands = getOperands(); 283 assert((start + length) <= operands.size()); 284 numOperands -= length; 285 286 // Shift all operands down if the operand to remove is not at the end. 287 if (start != numOperands) { 288 auto *indexIt = std::next(operands.begin(), start); 289 std::rotate(indexIt, std::next(indexIt, length), operands.end()); 290 } 291 for (unsigned i = 0; i != length; ++i) 292 operands[numOperands + i].~OpOperand(); 293 } 294 295 void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) { 296 MutableArrayRef<OpOperand> operands = getOperands(); 297 assert(eraseIndices.size() == operands.size()); 298 299 // Check that at least one operand is erased. 300 int firstErasedIndice = eraseIndices.find_first(); 301 if (firstErasedIndice == -1) 302 return; 303 304 // Shift all of the removed operands to the end, and destroy them. 305 numOperands = firstErasedIndice; 306 for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i) 307 if (!eraseIndices.test(i)) 308 operands[numOperands++] = std::move(operands[i]); 309 for (OpOperand &operand : operands.drop_front(numOperands)) 310 operand.~OpOperand(); 311 } 312 313 /// Resize the storage to the given size. Returns the array containing the new 314 /// operands. 315 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner, 316 unsigned newSize) { 317 // If the number of operands is less than or equal to the current amount, we 318 // can just update in place. 319 MutableArrayRef<OpOperand> origOperands = getOperands(); 320 if (newSize <= numOperands) { 321 // If the number of new size is less than the current, remove any extra 322 // operands. 323 for (unsigned i = newSize; i != numOperands; ++i) 324 origOperands[i].~OpOperand(); 325 numOperands = newSize; 326 return origOperands.take_front(newSize); 327 } 328 329 // If the new size is within the original inline capacity, grow inplace. 330 if (newSize <= capacity) { 331 OpOperand *opBegin = origOperands.data(); 332 for (unsigned e = newSize; numOperands != e; ++numOperands) 333 new (&opBegin[numOperands]) OpOperand(owner); 334 return MutableArrayRef<OpOperand>(opBegin, newSize); 335 } 336 337 // Otherwise, we need to allocate a new storage. 338 unsigned newCapacity = 339 std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize); 340 OpOperand *newOperandStorage = 341 reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity)); 342 343 // Move the current operands to the new storage. 344 MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize); 345 std::uninitialized_copy(std::make_move_iterator(origOperands.begin()), 346 std::make_move_iterator(origOperands.end()), 347 newOperands.begin()); 348 349 // Destroy the original operands. 350 for (auto &operand : origOperands) 351 operand.~OpOperand(); 352 353 // Initialize any new operands. 354 for (unsigned e = newSize; numOperands != e; ++numOperands) 355 new (&newOperands[numOperands]) OpOperand(owner); 356 357 // If the current storage is dynamic, free it. 358 if (isStorageDynamic) 359 free(operandStorage); 360 361 // Update the storage representation to use the new dynamic storage. 362 operandStorage = newOperandStorage; 363 capacity = newCapacity; 364 isStorageDynamic = true; 365 return newOperands; 366 } 367 368 //===----------------------------------------------------------------------===// 369 // Operation Value-Iterators 370 //===----------------------------------------------------------------------===// 371 372 //===----------------------------------------------------------------------===// 373 // OperandRange 374 375 unsigned OperandRange::getBeginOperandIndex() const { 376 assert(!empty() && "range must not be empty"); 377 return base->getOperandNumber(); 378 } 379 380 OperandRangeRange OperandRange::split(ElementsAttr segmentSizes) const { 381 return OperandRangeRange(*this, segmentSizes); 382 } 383 384 //===----------------------------------------------------------------------===// 385 // OperandRangeRange 386 387 OperandRangeRange::OperandRangeRange(OperandRange operands, 388 Attribute operandSegments) 389 : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0, 390 operandSegments.cast<DenseElementsAttr>().size()) {} 391 392 OperandRange OperandRangeRange::join() const { 393 const OwnerT &owner = getBase(); 394 auto sizeData = owner.second.cast<DenseElementsAttr>().getValues<uint32_t>(); 395 return OperandRange(owner.first, 396 std::accumulate(sizeData.begin(), sizeData.end(), 0)); 397 } 398 399 OperandRange OperandRangeRange::dereference(const OwnerT &object, 400 ptrdiff_t index) { 401 auto sizeData = object.second.cast<DenseElementsAttr>().getValues<uint32_t>(); 402 uint32_t startIndex = 403 std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); 404 return OperandRange(object.first + startIndex, *(sizeData.begin() + index)); 405 } 406 407 //===----------------------------------------------------------------------===// 408 // MutableOperandRange 409 410 /// Construct a new mutable range from the given operand, operand start index, 411 /// and range length. 412 MutableOperandRange::MutableOperandRange( 413 Operation *owner, unsigned start, unsigned length, 414 ArrayRef<OperandSegment> operandSegments) 415 : owner(owner), start(start), length(length), 416 operandSegments(operandSegments.begin(), operandSegments.end()) { 417 assert((start + length) <= owner->getNumOperands() && "invalid range"); 418 } 419 MutableOperandRange::MutableOperandRange(Operation *owner) 420 : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {} 421 422 /// Slice this range into a sub range, with the additional operand segment. 423 MutableOperandRange 424 MutableOperandRange::slice(unsigned subStart, unsigned subLen, 425 Optional<OperandSegment> segment) const { 426 assert((subStart + subLen) <= length && "invalid sub-range"); 427 MutableOperandRange subSlice(owner, start + subStart, subLen, 428 operandSegments); 429 if (segment) 430 subSlice.operandSegments.push_back(*segment); 431 return subSlice; 432 } 433 434 /// Append the given values to the range. 435 void MutableOperandRange::append(ValueRange values) { 436 if (values.empty()) 437 return; 438 owner->insertOperands(start + length, values); 439 updateLength(length + values.size()); 440 } 441 442 /// Assign this range to the given values. 443 void MutableOperandRange::assign(ValueRange values) { 444 owner->setOperands(start, length, values); 445 if (length != values.size()) 446 updateLength(/*newLength=*/values.size()); 447 } 448 449 /// Assign the range to the given value. 450 void MutableOperandRange::assign(Value value) { 451 if (length == 1) { 452 owner->setOperand(start, value); 453 } else { 454 owner->setOperands(start, length, value); 455 updateLength(/*newLength=*/1); 456 } 457 } 458 459 /// Erase the operands within the given sub-range. 460 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) { 461 assert((subStart + subLen) <= length && "invalid sub-range"); 462 if (length == 0) 463 return; 464 owner->eraseOperands(start + subStart, subLen); 465 updateLength(length - subLen); 466 } 467 468 /// Clear this range and erase all of the operands. 469 void MutableOperandRange::clear() { 470 if (length != 0) { 471 owner->eraseOperands(start, length); 472 updateLength(/*newLength=*/0); 473 } 474 } 475 476 /// Allow implicit conversion to an OperandRange. 477 MutableOperandRange::operator OperandRange() const { 478 return owner->getOperands().slice(start, length); 479 } 480 481 MutableOperandRangeRange 482 MutableOperandRange::split(NamedAttribute segmentSizes) const { 483 return MutableOperandRangeRange(*this, segmentSizes); 484 } 485 486 /// Update the length of this range to the one provided. 487 void MutableOperandRange::updateLength(unsigned newLength) { 488 int32_t diff = int32_t(newLength) - int32_t(length); 489 length = newLength; 490 491 // Update any of the provided segment attributes. 492 for (OperandSegment &segment : operandSegments) { 493 auto attr = segment.second.getValue().cast<DenseIntElementsAttr>(); 494 SmallVector<int32_t, 8> segments(attr.getValues<int32_t>()); 495 segments[segment.first] += diff; 496 segment.second.setValue( 497 DenseIntElementsAttr::get(attr.getType(), segments)); 498 owner->setAttr(segment.second.getName(), segment.second.getValue()); 499 } 500 } 501 502 //===----------------------------------------------------------------------===// 503 // MutableOperandRangeRange 504 505 MutableOperandRangeRange::MutableOperandRangeRange( 506 const MutableOperandRange &operands, NamedAttribute operandSegmentAttr) 507 : MutableOperandRangeRange( 508 OwnerT(operands, operandSegmentAttr), 0, 509 operandSegmentAttr.getValue().cast<DenseElementsAttr>().size()) {} 510 511 MutableOperandRange MutableOperandRangeRange::join() const { 512 return getBase().first; 513 } 514 515 MutableOperandRangeRange::operator OperandRangeRange() const { 516 return OperandRangeRange( 517 getBase().first, getBase().second.getValue().cast<DenseElementsAttr>()); 518 } 519 520 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object, 521 ptrdiff_t index) { 522 auto sizeData = 523 object.second.getValue().cast<DenseElementsAttr>().getValues<uint32_t>(); 524 uint32_t startIndex = 525 std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); 526 return object.first.slice( 527 startIndex, *(sizeData.begin() + index), 528 MutableOperandRange::OperandSegment(index, object.second)); 529 } 530 531 //===----------------------------------------------------------------------===// 532 // ResultRange 533 534 ResultRange::ResultRange(OpResult result) 535 : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()), 536 1) {} 537 538 ResultRange::use_range ResultRange::getUses() const { 539 return {use_begin(), use_end()}; 540 } 541 ResultRange::use_iterator ResultRange::use_begin() const { 542 return use_iterator(*this); 543 } 544 ResultRange::use_iterator ResultRange::use_end() const { 545 return use_iterator(*this, /*end=*/true); 546 } 547 ResultRange::user_range ResultRange::getUsers() { 548 return {user_begin(), user_end()}; 549 } 550 ResultRange::user_iterator ResultRange::user_begin() { 551 return user_iterator(use_begin()); 552 } 553 ResultRange::user_iterator ResultRange::user_end() { 554 return user_iterator(use_end()); 555 } 556 557 ResultRange::UseIterator::UseIterator(ResultRange results, bool end) 558 : it(end ? results.end() : results.begin()), endIt(results.end()) { 559 // Only initialize current use if there are results/can be uses. 560 if (it != endIt) 561 skipOverResultsWithNoUsers(); 562 } 563 564 ResultRange::UseIterator &ResultRange::UseIterator::operator++() { 565 // We increment over uses, if we reach the last use then move to next 566 // result. 567 if (use != (*it).use_end()) 568 ++use; 569 if (use == (*it).use_end()) { 570 ++it; 571 skipOverResultsWithNoUsers(); 572 } 573 return *this; 574 } 575 576 void ResultRange::UseIterator::skipOverResultsWithNoUsers() { 577 while (it != endIt && (*it).use_empty()) 578 ++it; 579 580 // If we are at the last result, then set use to first use of 581 // first result (sentinel value used for end). 582 if (it == endIt) 583 use = {}; 584 else 585 use = (*it).use_begin(); 586 } 587 588 void ResultRange::replaceAllUsesWith(Operation *op) { 589 replaceAllUsesWith(op->getResults()); 590 } 591 592 //===----------------------------------------------------------------------===// 593 // ValueRange 594 595 ValueRange::ValueRange(ArrayRef<Value> values) 596 : ValueRange(values.data(), values.size()) {} 597 ValueRange::ValueRange(OperandRange values) 598 : ValueRange(values.begin().getBase(), values.size()) {} 599 ValueRange::ValueRange(ResultRange values) 600 : ValueRange(values.getBase(), values.size()) {} 601 602 /// See `llvm::detail::indexed_accessor_range_base` for details. 603 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner, 604 ptrdiff_t index) { 605 if (const auto *value = owner.dyn_cast<const Value *>()) 606 return {value + index}; 607 if (auto *operand = owner.dyn_cast<OpOperand *>()) 608 return {operand + index}; 609 return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index); 610 } 611 /// See `llvm::detail::indexed_accessor_range_base` for details. 612 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) { 613 if (const auto *value = owner.dyn_cast<const Value *>()) 614 return value[index]; 615 if (auto *operand = owner.dyn_cast<OpOperand *>()) 616 return operand[index].get(); 617 return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // Operation Equivalency 622 //===----------------------------------------------------------------------===// 623 624 llvm::hash_code OperationEquivalence::computeHash( 625 Operation *op, function_ref<llvm::hash_code(Value)> hashOperands, 626 function_ref<llvm::hash_code(Value)> hashResults, Flags flags) { 627 // Hash operations based upon their: 628 // - Operation Name 629 // - Attributes 630 // - Result Types 631 llvm::hash_code hash = llvm::hash_combine( 632 op->getName(), op->getAttrDictionary(), op->getResultTypes()); 633 634 // - Operands 635 ValueRange operands = op->getOperands(); 636 SmallVector<Value> operandStorage; 637 if (op->hasTrait<mlir::OpTrait::IsCommutative>()) { 638 operandStorage.append(operands.begin(), operands.end()); 639 llvm::sort(operandStorage, [](Value a, Value b) -> bool { 640 return a.getAsOpaquePointer() < b.getAsOpaquePointer(); 641 }); 642 operands = operandStorage; 643 } 644 for (Value operand : operands) 645 hash = llvm::hash_combine(hash, hashOperands(operand)); 646 647 // - Operands 648 for (Value result : op->getResults()) 649 hash = llvm::hash_combine(hash, hashResults(result)); 650 return hash; 651 } 652 653 static bool 654 isRegionEquivalentTo(Region *lhs, Region *rhs, 655 function_ref<LogicalResult(Value, Value)> mapOperands, 656 function_ref<LogicalResult(Value, Value)> mapResults, 657 OperationEquivalence::Flags flags) { 658 DenseMap<Block *, Block *> blocksMap; 659 auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) { 660 // Check block arguments. 661 if (lBlock.getNumArguments() != rBlock.getNumArguments()) 662 return false; 663 664 // Map the two blocks. 665 auto insertion = blocksMap.insert({&lBlock, &rBlock}); 666 if (insertion.first->getSecond() != &rBlock) 667 return false; 668 669 for (auto argPair : 670 llvm::zip(lBlock.getArguments(), rBlock.getArguments())) { 671 Value curArg = std::get<0>(argPair); 672 Value otherArg = std::get<1>(argPair); 673 if (curArg.getType() != otherArg.getType()) 674 return false; 675 if (!(flags & OperationEquivalence::IgnoreLocations) && 676 curArg.getLoc() != otherArg.getLoc()) 677 return false; 678 // Check if this value was already mapped to another value. 679 if (failed(mapOperands(curArg, otherArg))) 680 return false; 681 } 682 683 auto opsEquivalent = [&](Operation &lOp, Operation &rOp) { 684 // Check for op equality (recursively). 685 if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, mapOperands, 686 mapResults, flags)) 687 return false; 688 // Check successor mapping. 689 for (auto successorsPair : 690 llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) { 691 Block *curSuccessor = std::get<0>(successorsPair); 692 Block *otherSuccessor = std::get<1>(successorsPair); 693 auto insertion = blocksMap.insert({curSuccessor, otherSuccessor}); 694 if (insertion.first->getSecond() != otherSuccessor) 695 return false; 696 } 697 return true; 698 }; 699 return llvm::all_of_zip(lBlock, rBlock, opsEquivalent); 700 }; 701 return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent); 702 } 703 704 bool OperationEquivalence::isEquivalentTo( 705 Operation *lhs, Operation *rhs, 706 function_ref<LogicalResult(Value, Value)> mapOperands, 707 function_ref<LogicalResult(Value, Value)> mapResults, Flags flags) { 708 if (lhs == rhs) 709 return true; 710 711 // Compare the operation properties. 712 if (lhs->getName() != rhs->getName() || 713 lhs->getAttrDictionary() != rhs->getAttrDictionary() || 714 lhs->getNumRegions() != rhs->getNumRegions() || 715 lhs->getNumSuccessors() != rhs->getNumSuccessors() || 716 lhs->getNumOperands() != rhs->getNumOperands() || 717 lhs->getNumResults() != rhs->getNumResults()) 718 return false; 719 if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc()) 720 return false; 721 722 ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands(); 723 SmallVector<Value> lhsOperandStorage, rhsOperandStorage; 724 if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) { 725 lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end()); 726 llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool { 727 return a.getAsOpaquePointer() < b.getAsOpaquePointer(); 728 }); 729 lhsOperands = lhsOperandStorage; 730 731 rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end()); 732 llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool { 733 return a.getAsOpaquePointer() < b.getAsOpaquePointer(); 734 }); 735 rhsOperands = rhsOperandStorage; 736 } 737 auto checkValueRangeMapping = 738 [](ValueRange lhs, ValueRange rhs, 739 function_ref<LogicalResult(Value, Value)> mapValues) { 740 for (auto operandPair : llvm::zip(lhs, rhs)) { 741 Value curArg = std::get<0>(operandPair); 742 Value otherArg = std::get<1>(operandPair); 743 if (curArg.getType() != otherArg.getType()) 744 return false; 745 if (failed(mapValues(curArg, otherArg))) 746 return false; 747 } 748 return true; 749 }; 750 // Check mapping of operands and results. 751 if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands)) 752 return false; 753 if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults)) 754 return false; 755 for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions())) 756 if (!isRegionEquivalentTo(&std::get<0>(regionPair), 757 &std::get<1>(regionPair), mapOperands, mapResults, 758 flags)) 759 return false; 760 return true; 761 } 762