1 //===- OperationSupport.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains out-of-line implementations of the support types that 10 // Operation and related classes build on top of. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/OperationSupport.h" 15 #include "mlir/IR/BuiltinAttributes.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/OpDefinition.h" 18 #include "llvm/ADT/BitVector.h" 19 #include <numeric> 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // NamedAttrList 25 //===----------------------------------------------------------------------===// 26 27 NamedAttrList::NamedAttrList(ArrayRef<NamedAttribute> attributes) { 28 assign(attributes.begin(), attributes.end()); 29 } 30 31 NamedAttrList::NamedAttrList(DictionaryAttr attributes) 32 : NamedAttrList(attributes ? attributes.getValue() 33 : ArrayRef<NamedAttribute>()) { 34 dictionarySorted.setPointerAndInt(attributes, true); 35 } 36 37 NamedAttrList::NamedAttrList(const_iterator inStart, const_iterator inEnd) { 38 assign(inStart, inEnd); 39 } 40 41 ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; } 42 43 Optional<NamedAttribute> NamedAttrList::findDuplicate() const { 44 Optional<NamedAttribute> duplicate = 45 DictionaryAttr::findDuplicate(attrs, isSorted()); 46 // DictionaryAttr::findDuplicate will sort the list, so reset the sorted 47 // state. 48 if (!isSorted()) 49 dictionarySorted.setPointerAndInt(nullptr, true); 50 return duplicate; 51 } 52 53 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const { 54 if (!isSorted()) { 55 DictionaryAttr::sortInPlace(attrs); 56 dictionarySorted.setPointerAndInt(nullptr, true); 57 } 58 if (!dictionarySorted.getPointer()) 59 dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs)); 60 return dictionarySorted.getPointer().cast<DictionaryAttr>(); 61 } 62 63 /// Add an attribute with the specified name. 64 void NamedAttrList::append(StringRef name, Attribute attr) { 65 append(StringAttr::get(attr.getContext(), name), attr); 66 } 67 68 /// Replaces the attributes with new list of attributes. 69 void NamedAttrList::assign(const_iterator inStart, const_iterator inEnd) { 70 DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs); 71 dictionarySorted.setPointerAndInt(nullptr, true); 72 } 73 74 void NamedAttrList::push_back(NamedAttribute newAttribute) { 75 if (isSorted()) 76 dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute); 77 dictionarySorted.setPointer(nullptr); 78 attrs.push_back(newAttribute); 79 } 80 81 /// Return the specified attribute if present, null otherwise. 82 Attribute NamedAttrList::get(StringRef name) const { 83 auto it = findAttr(*this, name); 84 return it.second ? it.first->getValue() : Attribute(); 85 } 86 Attribute NamedAttrList::get(StringAttr name) const { 87 auto it = findAttr(*this, name); 88 return it.second ? it.first->getValue() : Attribute(); 89 } 90 91 /// Return the specified named attribute if present, None otherwise. 92 Optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const { 93 auto it = findAttr(*this, name); 94 return it.second ? *it.first : Optional<NamedAttribute>(); 95 } 96 Optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const { 97 auto it = findAttr(*this, name); 98 return it.second ? *it.first : Optional<NamedAttribute>(); 99 } 100 101 /// If the an attribute exists with the specified name, change it to the new 102 /// value. Otherwise, add a new attribute with the specified name/value. 103 Attribute NamedAttrList::set(StringAttr name, Attribute value) { 104 assert(value && "attributes may never be null"); 105 106 // Look for an existing attribute with the given name, and set its value 107 // in-place. Return the previous value of the attribute, if there was one. 108 auto it = findAttr(*this, name); 109 if (it.second) { 110 // Update the existing attribute by swapping out the old value for the new 111 // value. Return the old value. 112 Attribute oldValue = it.first->getValue(); 113 if (it.first->getValue() != value) { 114 it.first->setValue(value); 115 116 // If the attributes have changed, the dictionary is invalidated. 117 dictionarySorted.setPointer(nullptr); 118 } 119 return oldValue; 120 } 121 // Perform a string lookup to insert the new attribute into its sorted 122 // position. 123 if (isSorted()) 124 it = findAttr(*this, name.strref()); 125 attrs.insert(it.first, {name, value}); 126 // Invalidate the dictionary. Return null as there was no previous value. 127 dictionarySorted.setPointer(nullptr); 128 return Attribute(); 129 } 130 131 Attribute NamedAttrList::set(StringRef name, Attribute value) { 132 assert(value && "attributes may never be null"); 133 return set(mlir::StringAttr::get(value.getContext(), name), value); 134 } 135 136 Attribute 137 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) { 138 // Erasing does not affect the sorted property. 139 Attribute attr = it->getValue(); 140 attrs.erase(it); 141 dictionarySorted.setPointer(nullptr); 142 return attr; 143 } 144 145 Attribute NamedAttrList::erase(StringAttr name) { 146 auto it = findAttr(*this, name); 147 return it.second ? eraseImpl(it.first) : Attribute(); 148 } 149 150 Attribute NamedAttrList::erase(StringRef name) { 151 auto it = findAttr(*this, name); 152 return it.second ? eraseImpl(it.first) : Attribute(); 153 } 154 155 NamedAttrList & 156 NamedAttrList::operator=(const SmallVectorImpl<NamedAttribute> &rhs) { 157 assign(rhs.begin(), rhs.end()); 158 return *this; 159 } 160 161 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; } 162 163 //===----------------------------------------------------------------------===// 164 // OperationState 165 //===----------------------------------------------------------------------===// 166 167 OperationState::OperationState(Location location, StringRef name) 168 : location(location), name(name, location->getContext()) {} 169 170 OperationState::OperationState(Location location, OperationName name) 171 : location(location), name(name) {} 172 173 OperationState::OperationState(Location location, OperationName name, 174 ValueRange operands, TypeRange types, 175 ArrayRef<NamedAttribute> attributes, 176 BlockRange successors, 177 MutableArrayRef<std::unique_ptr<Region>> regions) 178 : location(location), name(name), 179 operands(operands.begin(), operands.end()), 180 types(types.begin(), types.end()), 181 attributes(attributes.begin(), attributes.end()), 182 successors(successors.begin(), successors.end()) { 183 for (std::unique_ptr<Region> &r : regions) 184 this->regions.push_back(std::move(r)); 185 } 186 OperationState::OperationState(Location location, StringRef name, 187 ValueRange operands, TypeRange types, 188 ArrayRef<NamedAttribute> attributes, 189 BlockRange successors, 190 MutableArrayRef<std::unique_ptr<Region>> regions) 191 : OperationState(location, OperationName(name, location.getContext()), 192 operands, types, attributes, successors, regions) {} 193 194 void OperationState::addOperands(ValueRange newOperands) { 195 operands.append(newOperands.begin(), newOperands.end()); 196 } 197 198 void OperationState::addSuccessors(BlockRange newSuccessors) { 199 successors.append(newSuccessors.begin(), newSuccessors.end()); 200 } 201 202 Region *OperationState::addRegion() { 203 regions.emplace_back(new Region); 204 return regions.back().get(); 205 } 206 207 void OperationState::addRegion(std::unique_ptr<Region> &®ion) { 208 regions.push_back(std::move(region)); 209 } 210 211 void OperationState::addRegions( 212 MutableArrayRef<std::unique_ptr<Region>> regions) { 213 for (std::unique_ptr<Region> ®ion : regions) 214 addRegion(std::move(region)); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // OperandStorage 219 //===----------------------------------------------------------------------===// 220 221 detail::OperandStorage::OperandStorage(Operation *owner, 222 OpOperand *trailingOperands, 223 ValueRange values) 224 : isStorageDynamic(false), operandStorage(trailingOperands) { 225 numOperands = capacity = values.size(); 226 for (unsigned i = 0; i < numOperands; ++i) 227 new (&operandStorage[i]) OpOperand(owner, values[i]); 228 } 229 230 detail::OperandStorage::~OperandStorage() { 231 for (auto &operand : getOperands()) 232 operand.~OpOperand(); 233 234 // If the storage is dynamic, deallocate it. 235 if (isStorageDynamic) 236 free(operandStorage); 237 } 238 239 /// Replace the operands contained in the storage with the ones provided in 240 /// 'values'. 241 void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) { 242 MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size()); 243 for (unsigned i = 0, e = values.size(); i != e; ++i) 244 storageOperands[i].set(values[i]); 245 } 246 247 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 248 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 249 /// than the range pointed to by 'start'+'length'. 250 void detail::OperandStorage::setOperands(Operation *owner, unsigned start, 251 unsigned length, ValueRange operands) { 252 // If the new size is the same, we can update inplace. 253 unsigned newSize = operands.size(); 254 if (newSize == length) { 255 MutableArrayRef<OpOperand> storageOperands = getOperands(); 256 for (unsigned i = 0, e = length; i != e; ++i) 257 storageOperands[start + i].set(operands[i]); 258 return; 259 } 260 // If the new size is greater, remove the extra operands and set the rest 261 // inplace. 262 if (newSize < length) { 263 eraseOperands(start + operands.size(), length - newSize); 264 setOperands(owner, start, newSize, operands); 265 return; 266 } 267 // Otherwise, the new size is greater so we need to grow the storage. 268 auto storageOperands = resize(owner, size() + (newSize - length)); 269 270 // Shift operands to the right to make space for the new operands. 271 unsigned rotateSize = storageOperands.size() - (start + length); 272 auto rbegin = storageOperands.rbegin(); 273 std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize); 274 275 // Update the operands inplace. 276 for (unsigned i = 0, e = operands.size(); i != e; ++i) 277 storageOperands[start + i].set(operands[i]); 278 } 279 280 /// Erase an operand held by the storage. 281 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) { 282 MutableArrayRef<OpOperand> operands = getOperands(); 283 assert((start + length) <= operands.size()); 284 numOperands -= length; 285 286 // Shift all operands down if the operand to remove is not at the end. 287 if (start != numOperands) { 288 auto *indexIt = std::next(operands.begin(), start); 289 std::rotate(indexIt, std::next(indexIt, length), operands.end()); 290 } 291 for (unsigned i = 0; i != length; ++i) 292 operands[numOperands + i].~OpOperand(); 293 } 294 295 void detail::OperandStorage::eraseOperands( 296 const BitVector &eraseIndices) { 297 MutableArrayRef<OpOperand> operands = getOperands(); 298 assert(eraseIndices.size() == operands.size()); 299 300 // Check that at least one operand is erased. 301 int firstErasedIndice = eraseIndices.find_first(); 302 if (firstErasedIndice == -1) 303 return; 304 305 // Shift all of the removed operands to the end, and destroy them. 306 numOperands = firstErasedIndice; 307 for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i) 308 if (!eraseIndices.test(i)) 309 operands[numOperands++] = std::move(operands[i]); 310 for (OpOperand &operand : operands.drop_front(numOperands)) 311 operand.~OpOperand(); 312 } 313 314 /// Resize the storage to the given size. Returns the array containing the new 315 /// operands. 316 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner, 317 unsigned newSize) { 318 // If the number of operands is less than or equal to the current amount, we 319 // can just update in place. 320 MutableArrayRef<OpOperand> origOperands = getOperands(); 321 if (newSize <= numOperands) { 322 // If the number of new size is less than the current, remove any extra 323 // operands. 324 for (unsigned i = newSize; i != numOperands; ++i) 325 origOperands[i].~OpOperand(); 326 numOperands = newSize; 327 return origOperands.take_front(newSize); 328 } 329 330 // If the new size is within the original inline capacity, grow inplace. 331 if (newSize <= capacity) { 332 OpOperand *opBegin = origOperands.data(); 333 for (unsigned e = newSize; numOperands != e; ++numOperands) 334 new (&opBegin[numOperands]) OpOperand(owner); 335 return MutableArrayRef<OpOperand>(opBegin, newSize); 336 } 337 338 // Otherwise, we need to allocate a new storage. 339 unsigned newCapacity = 340 std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize); 341 OpOperand *newOperandStorage = 342 reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity)); 343 344 // Move the current operands to the new storage. 345 MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize); 346 std::uninitialized_copy(std::make_move_iterator(origOperands.begin()), 347 std::make_move_iterator(origOperands.end()), 348 newOperands.begin()); 349 350 // Destroy the original operands. 351 for (auto &operand : origOperands) 352 operand.~OpOperand(); 353 354 // Initialize any new operands. 355 for (unsigned e = newSize; numOperands != e; ++numOperands) 356 new (&newOperands[numOperands]) OpOperand(owner); 357 358 // If the current storage is dynamic, free it. 359 if (isStorageDynamic) 360 free(operandStorage); 361 362 // Update the storage representation to use the new dynamic storage. 363 operandStorage = newOperandStorage; 364 capacity = newCapacity; 365 isStorageDynamic = true; 366 return newOperands; 367 } 368 369 //===----------------------------------------------------------------------===// 370 // Operation Value-Iterators 371 //===----------------------------------------------------------------------===// 372 373 //===----------------------------------------------------------------------===// 374 // OperandRange 375 376 unsigned OperandRange::getBeginOperandIndex() const { 377 assert(!empty() && "range must not be empty"); 378 return base->getOperandNumber(); 379 } 380 381 OperandRangeRange OperandRange::split(ElementsAttr segmentSizes) const { 382 return OperandRangeRange(*this, segmentSizes); 383 } 384 385 //===----------------------------------------------------------------------===// 386 // OperandRangeRange 387 388 OperandRangeRange::OperandRangeRange(OperandRange operands, 389 Attribute operandSegments) 390 : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0, 391 operandSegments.cast<DenseElementsAttr>().size()) {} 392 393 OperandRange OperandRangeRange::join() const { 394 const OwnerT &owner = getBase(); 395 auto sizeData = owner.second.cast<DenseElementsAttr>().getValues<uint32_t>(); 396 return OperandRange(owner.first, 397 std::accumulate(sizeData.begin(), sizeData.end(), 0)); 398 } 399 400 OperandRange OperandRangeRange::dereference(const OwnerT &object, 401 ptrdiff_t index) { 402 auto sizeData = object.second.cast<DenseElementsAttr>().getValues<uint32_t>(); 403 uint32_t startIndex = 404 std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); 405 return OperandRange(object.first + startIndex, *(sizeData.begin() + index)); 406 } 407 408 //===----------------------------------------------------------------------===// 409 // MutableOperandRange 410 411 /// Construct a new mutable range from the given operand, operand start index, 412 /// and range length. 413 MutableOperandRange::MutableOperandRange( 414 Operation *owner, unsigned start, unsigned length, 415 ArrayRef<OperandSegment> operandSegments) 416 : owner(owner), start(start), length(length), 417 operandSegments(operandSegments.begin(), operandSegments.end()) { 418 assert((start + length) <= owner->getNumOperands() && "invalid range"); 419 } 420 MutableOperandRange::MutableOperandRange(Operation *owner) 421 : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {} 422 423 /// Slice this range into a sub range, with the additional operand segment. 424 MutableOperandRange 425 MutableOperandRange::slice(unsigned subStart, unsigned subLen, 426 Optional<OperandSegment> segment) const { 427 assert((subStart + subLen) <= length && "invalid sub-range"); 428 MutableOperandRange subSlice(owner, start + subStart, subLen, 429 operandSegments); 430 if (segment) 431 subSlice.operandSegments.push_back(*segment); 432 return subSlice; 433 } 434 435 /// Append the given values to the range. 436 void MutableOperandRange::append(ValueRange values) { 437 if (values.empty()) 438 return; 439 owner->insertOperands(start + length, values); 440 updateLength(length + values.size()); 441 } 442 443 /// Assign this range to the given values. 444 void MutableOperandRange::assign(ValueRange values) { 445 owner->setOperands(start, length, values); 446 if (length != values.size()) 447 updateLength(/*newLength=*/values.size()); 448 } 449 450 /// Assign the range to the given value. 451 void MutableOperandRange::assign(Value value) { 452 if (length == 1) { 453 owner->setOperand(start, value); 454 } else { 455 owner->setOperands(start, length, value); 456 updateLength(/*newLength=*/1); 457 } 458 } 459 460 /// Erase the operands within the given sub-range. 461 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) { 462 assert((subStart + subLen) <= length && "invalid sub-range"); 463 if (length == 0) 464 return; 465 owner->eraseOperands(start + subStart, subLen); 466 updateLength(length - subLen); 467 } 468 469 /// Clear this range and erase all of the operands. 470 void MutableOperandRange::clear() { 471 if (length != 0) { 472 owner->eraseOperands(start, length); 473 updateLength(/*newLength=*/0); 474 } 475 } 476 477 /// Allow implicit conversion to an OperandRange. 478 MutableOperandRange::operator OperandRange() const { 479 return owner->getOperands().slice(start, length); 480 } 481 482 MutableOperandRangeRange 483 MutableOperandRange::split(NamedAttribute segmentSizes) const { 484 return MutableOperandRangeRange(*this, segmentSizes); 485 } 486 487 /// Update the length of this range to the one provided. 488 void MutableOperandRange::updateLength(unsigned newLength) { 489 int32_t diff = int32_t(newLength) - int32_t(length); 490 length = newLength; 491 492 // Update any of the provided segment attributes. 493 for (OperandSegment &segment : operandSegments) { 494 auto attr = segment.second.getValue().cast<DenseIntElementsAttr>(); 495 SmallVector<int32_t, 8> segments(attr.getValues<int32_t>()); 496 segments[segment.first] += diff; 497 segment.second.setValue( 498 DenseIntElementsAttr::get(attr.getType(), segments)); 499 owner->setAttr(segment.second.getName(), segment.second.getValue()); 500 } 501 } 502 503 //===----------------------------------------------------------------------===// 504 // MutableOperandRangeRange 505 506 MutableOperandRangeRange::MutableOperandRangeRange( 507 const MutableOperandRange &operands, NamedAttribute operandSegmentAttr) 508 : MutableOperandRangeRange( 509 OwnerT(operands, operandSegmentAttr), 0, 510 operandSegmentAttr.getValue().cast<DenseElementsAttr>().size()) {} 511 512 MutableOperandRange MutableOperandRangeRange::join() const { 513 return getBase().first; 514 } 515 516 MutableOperandRangeRange::operator OperandRangeRange() const { 517 return OperandRangeRange( 518 getBase().first, getBase().second.getValue().cast<DenseElementsAttr>()); 519 } 520 521 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object, 522 ptrdiff_t index) { 523 auto sizeData = 524 object.second.getValue().cast<DenseElementsAttr>().getValues<uint32_t>(); 525 uint32_t startIndex = 526 std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); 527 return object.first.slice( 528 startIndex, *(sizeData.begin() + index), 529 MutableOperandRange::OperandSegment(index, object.second)); 530 } 531 532 //===----------------------------------------------------------------------===// 533 // ResultRange 534 535 ResultRange::ResultRange(OpResult result) 536 : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()), 537 1) {} 538 539 ResultRange::use_range ResultRange::getUses() const { 540 return {use_begin(), use_end()}; 541 } 542 ResultRange::use_iterator ResultRange::use_begin() const { 543 return use_iterator(*this); 544 } 545 ResultRange::use_iterator ResultRange::use_end() const { 546 return use_iterator(*this, /*end=*/true); 547 } 548 ResultRange::user_range ResultRange::getUsers() { 549 return {user_begin(), user_end()}; 550 } 551 ResultRange::user_iterator ResultRange::user_begin() { 552 return user_iterator(use_begin()); 553 } 554 ResultRange::user_iterator ResultRange::user_end() { 555 return user_iterator(use_end()); 556 } 557 558 ResultRange::UseIterator::UseIterator(ResultRange results, bool end) 559 : it(end ? results.end() : results.begin()), endIt(results.end()) { 560 // Only initialize current use if there are results/can be uses. 561 if (it != endIt) 562 skipOverResultsWithNoUsers(); 563 } 564 565 ResultRange::UseIterator &ResultRange::UseIterator::operator++() { 566 // We increment over uses, if we reach the last use then move to next 567 // result. 568 if (use != (*it).use_end()) 569 ++use; 570 if (use == (*it).use_end()) { 571 ++it; 572 skipOverResultsWithNoUsers(); 573 } 574 return *this; 575 } 576 577 void ResultRange::UseIterator::skipOverResultsWithNoUsers() { 578 while (it != endIt && (*it).use_empty()) 579 ++it; 580 581 // If we are at the last result, then set use to first use of 582 // first result (sentinel value used for end). 583 if (it == endIt) 584 use = {}; 585 else 586 use = (*it).use_begin(); 587 } 588 589 void ResultRange::replaceAllUsesWith(Operation *op) { 590 replaceAllUsesWith(op->getResults()); 591 } 592 593 //===----------------------------------------------------------------------===// 594 // ValueRange 595 596 ValueRange::ValueRange(ArrayRef<Value> values) 597 : ValueRange(values.data(), values.size()) {} 598 ValueRange::ValueRange(OperandRange values) 599 : ValueRange(values.begin().getBase(), values.size()) {} 600 ValueRange::ValueRange(ResultRange values) 601 : ValueRange(values.getBase(), values.size()) {} 602 603 /// See `llvm::detail::indexed_accessor_range_base` for details. 604 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner, 605 ptrdiff_t index) { 606 if (const auto *value = owner.dyn_cast<const Value *>()) 607 return {value + index}; 608 if (auto *operand = owner.dyn_cast<OpOperand *>()) 609 return {operand + index}; 610 return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index); 611 } 612 /// See `llvm::detail::indexed_accessor_range_base` for details. 613 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) { 614 if (const auto *value = owner.dyn_cast<const Value *>()) 615 return value[index]; 616 if (auto *operand = owner.dyn_cast<OpOperand *>()) 617 return operand[index].get(); 618 return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index); 619 } 620 621 //===----------------------------------------------------------------------===// 622 // Operation Equivalency 623 //===----------------------------------------------------------------------===// 624 625 llvm::hash_code OperationEquivalence::computeHash( 626 Operation *op, function_ref<llvm::hash_code(Value)> hashOperands, 627 function_ref<llvm::hash_code(Value)> hashResults, Flags flags) { 628 // Hash operations based upon their: 629 // - Operation Name 630 // - Attributes 631 // - Result Types 632 llvm::hash_code hash = llvm::hash_combine( 633 op->getName(), op->getAttrDictionary(), op->getResultTypes()); 634 635 // - Operands 636 ValueRange operands = op->getOperands(); 637 SmallVector<Value> operandStorage; 638 if (op->hasTrait<mlir::OpTrait::IsCommutative>()) { 639 operandStorage.append(operands.begin(), operands.end()); 640 llvm::sort(operandStorage, [](Value a, Value b) -> bool { 641 return a.getAsOpaquePointer() < b.getAsOpaquePointer(); 642 }); 643 operands = operandStorage; 644 } 645 for (Value operand : operands) 646 hash = llvm::hash_combine(hash, hashOperands(operand)); 647 648 // - Operands 649 for (Value result : op->getResults()) 650 hash = llvm::hash_combine(hash, hashResults(result)); 651 return hash; 652 } 653 654 static bool 655 isRegionEquivalentTo(Region *lhs, Region *rhs, 656 function_ref<LogicalResult(Value, Value)> mapOperands, 657 function_ref<LogicalResult(Value, Value)> mapResults, 658 OperationEquivalence::Flags flags) { 659 DenseMap<Block *, Block *> blocksMap; 660 auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) { 661 // Check block arguments. 662 if (lBlock.getNumArguments() != rBlock.getNumArguments()) 663 return false; 664 665 // Map the two blocks. 666 auto insertion = blocksMap.insert({&lBlock, &rBlock}); 667 if (insertion.first->getSecond() != &rBlock) 668 return false; 669 670 for (auto argPair : 671 llvm::zip(lBlock.getArguments(), rBlock.getArguments())) { 672 Value curArg = std::get<0>(argPair); 673 Value otherArg = std::get<1>(argPair); 674 if (curArg.getType() != otherArg.getType()) 675 return false; 676 if (!(flags & OperationEquivalence::IgnoreLocations) && 677 curArg.getLoc() != otherArg.getLoc()) 678 return false; 679 // Check if this value was already mapped to another value. 680 if (failed(mapOperands(curArg, otherArg))) 681 return false; 682 } 683 684 auto opsEquivalent = [&](Operation &lOp, Operation &rOp) { 685 // Check for op equality (recursively). 686 if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, mapOperands, 687 mapResults, flags)) 688 return false; 689 // Check successor mapping. 690 for (auto successorsPair : 691 llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) { 692 Block *curSuccessor = std::get<0>(successorsPair); 693 Block *otherSuccessor = std::get<1>(successorsPair); 694 auto insertion = blocksMap.insert({curSuccessor, otherSuccessor}); 695 if (insertion.first->getSecond() != otherSuccessor) 696 return false; 697 } 698 return true; 699 }; 700 return llvm::all_of_zip(lBlock, rBlock, opsEquivalent); 701 }; 702 return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent); 703 } 704 705 bool OperationEquivalence::isEquivalentTo( 706 Operation *lhs, Operation *rhs, 707 function_ref<LogicalResult(Value, Value)> mapOperands, 708 function_ref<LogicalResult(Value, Value)> mapResults, Flags flags) { 709 if (lhs == rhs) 710 return true; 711 712 // Compare the operation properties. 713 if (lhs->getName() != rhs->getName() || 714 lhs->getAttrDictionary() != rhs->getAttrDictionary() || 715 lhs->getNumRegions() != rhs->getNumRegions() || 716 lhs->getNumSuccessors() != rhs->getNumSuccessors() || 717 lhs->getNumOperands() != rhs->getNumOperands() || 718 lhs->getNumResults() != rhs->getNumResults()) 719 return false; 720 if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc()) 721 return false; 722 723 ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands(); 724 SmallVector<Value> lhsOperandStorage, rhsOperandStorage; 725 if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) { 726 lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end()); 727 llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool { 728 return a.getAsOpaquePointer() < b.getAsOpaquePointer(); 729 }); 730 lhsOperands = lhsOperandStorage; 731 732 rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end()); 733 llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool { 734 return a.getAsOpaquePointer() < b.getAsOpaquePointer(); 735 }); 736 rhsOperands = rhsOperandStorage; 737 } 738 auto checkValueRangeMapping = 739 [](ValueRange lhs, ValueRange rhs, 740 function_ref<LogicalResult(Value, Value)> mapValues) { 741 for (auto operandPair : llvm::zip(lhs, rhs)) { 742 Value curArg = std::get<0>(operandPair); 743 Value otherArg = std::get<1>(operandPair); 744 if (curArg.getType() != otherArg.getType()) 745 return false; 746 if (failed(mapValues(curArg, otherArg))) 747 return false; 748 } 749 return true; 750 }; 751 // Check mapping of operands and results. 752 if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands)) 753 return false; 754 if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults)) 755 return false; 756 for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions())) 757 if (!isRegionEquivalentTo(&std::get<0>(regionPair), 758 &std::get<1>(regionPair), mapOperands, mapResults, 759 flags)) 760 return false; 761 return true; 762 } 763