1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/OperationSupport.h"
15 #include "mlir/IR/BuiltinAttributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/BitVector.h"
19 #include <numeric>
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // NamedAttrList
25 //===----------------------------------------------------------------------===//
26 
27 NamedAttrList::NamedAttrList(ArrayRef<NamedAttribute> attributes) {
28   assign(attributes.begin(), attributes.end());
29 }
30 
31 NamedAttrList::NamedAttrList(DictionaryAttr attributes)
32     : NamedAttrList(attributes ? attributes.getValue()
33                                : ArrayRef<NamedAttribute>()) {
34   dictionarySorted.setPointerAndInt(attributes, true);
35 }
36 
37 NamedAttrList::NamedAttrList(const_iterator inStart, const_iterator inEnd) {
38   assign(inStart, inEnd);
39 }
40 
41 ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; }
42 
43 Optional<NamedAttribute> NamedAttrList::findDuplicate() const {
44   Optional<NamedAttribute> duplicate =
45       DictionaryAttr::findDuplicate(attrs, isSorted());
46   // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
47   // state.
48   if (!isSorted())
49     dictionarySorted.setPointerAndInt(nullptr, true);
50   return duplicate;
51 }
52 
53 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
54   if (!isSorted()) {
55     DictionaryAttr::sortInPlace(attrs);
56     dictionarySorted.setPointerAndInt(nullptr, true);
57   }
58   if (!dictionarySorted.getPointer())
59     dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
60   return dictionarySorted.getPointer().cast<DictionaryAttr>();
61 }
62 
63 /// Add an attribute with the specified name.
64 void NamedAttrList::append(StringRef name, Attribute attr) {
65   append(StringAttr::get(attr.getContext(), name), attr);
66 }
67 
68 /// Replaces the attributes with new list of attributes.
69 void NamedAttrList::assign(const_iterator inStart, const_iterator inEnd) {
70   DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
71   dictionarySorted.setPointerAndInt(nullptr, true);
72 }
73 
74 void NamedAttrList::push_back(NamedAttribute newAttribute) {
75   if (isSorted())
76     dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
77   dictionarySorted.setPointer(nullptr);
78   attrs.push_back(newAttribute);
79 }
80 
81 /// Return the specified attribute if present, null otherwise.
82 Attribute NamedAttrList::get(StringRef name) const {
83   auto it = findAttr(*this, name);
84   return it.second ? it.first->getValue() : Attribute();
85 }
86 Attribute NamedAttrList::get(StringAttr name) const {
87   auto it = findAttr(*this, name);
88   return it.second ? it.first->getValue() : Attribute();
89 }
90 
91 /// Return the specified named attribute if present, None otherwise.
92 Optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
93   auto it = findAttr(*this, name);
94   return it.second ? *it.first : Optional<NamedAttribute>();
95 }
96 Optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
97   auto it = findAttr(*this, name);
98   return it.second ? *it.first : Optional<NamedAttribute>();
99 }
100 
101 /// If the an attribute exists with the specified name, change it to the new
102 /// value.  Otherwise, add a new attribute with the specified name/value.
103 Attribute NamedAttrList::set(StringAttr name, Attribute value) {
104   assert(value && "attributes may never be null");
105 
106   // Look for an existing attribute with the given name, and set its value
107   // in-place. Return the previous value of the attribute, if there was one.
108   auto it = findAttr(*this, name);
109   if (it.second) {
110     // Update the existing attribute by swapping out the old value for the new
111     // value. Return the old value.
112     Attribute oldValue = it.first->getValue();
113     if (it.first->getValue() != value) {
114       it.first->setValue(value);
115 
116       // If the attributes have changed, the dictionary is invalidated.
117       dictionarySorted.setPointer(nullptr);
118     }
119     return oldValue;
120   }
121   // Perform a string lookup to insert the new attribute into its sorted
122   // position.
123   if (isSorted())
124     it = findAttr(*this, name.strref());
125   attrs.insert(it.first, {name, value});
126   // Invalidate the dictionary. Return null as there was no previous value.
127   dictionarySorted.setPointer(nullptr);
128   return Attribute();
129 }
130 
131 Attribute NamedAttrList::set(StringRef name, Attribute value) {
132   assert(value && "attributes may never be null");
133   return set(mlir::StringAttr::get(value.getContext(), name), value);
134 }
135 
136 Attribute
137 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
138   // Erasing does not affect the sorted property.
139   Attribute attr = it->getValue();
140   attrs.erase(it);
141   dictionarySorted.setPointer(nullptr);
142   return attr;
143 }
144 
145 Attribute NamedAttrList::erase(StringAttr name) {
146   auto it = findAttr(*this, name);
147   return it.second ? eraseImpl(it.first) : Attribute();
148 }
149 
150 Attribute NamedAttrList::erase(StringRef name) {
151   auto it = findAttr(*this, name);
152   return it.second ? eraseImpl(it.first) : Attribute();
153 }
154 
155 NamedAttrList &
156 NamedAttrList::operator=(const SmallVectorImpl<NamedAttribute> &rhs) {
157   assign(rhs.begin(), rhs.end());
158   return *this;
159 }
160 
161 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
162 
163 //===----------------------------------------------------------------------===//
164 // OperationState
165 //===----------------------------------------------------------------------===//
166 
167 OperationState::OperationState(Location location, StringRef name)
168     : location(location), name(name, location->getContext()) {}
169 
170 OperationState::OperationState(Location location, OperationName name)
171     : location(location), name(name) {}
172 
173 OperationState::OperationState(Location location, OperationName name,
174                                ValueRange operands, TypeRange types,
175                                ArrayRef<NamedAttribute> attributes,
176                                BlockRange successors,
177                                MutableArrayRef<std::unique_ptr<Region>> regions)
178     : location(location), name(name),
179       operands(operands.begin(), operands.end()),
180       types(types.begin(), types.end()),
181       attributes(attributes.begin(), attributes.end()),
182       successors(successors.begin(), successors.end()) {
183   for (std::unique_ptr<Region> &r : regions)
184     this->regions.push_back(std::move(r));
185 }
186 OperationState::OperationState(Location location, StringRef name,
187                                ValueRange operands, TypeRange types,
188                                ArrayRef<NamedAttribute> attributes,
189                                BlockRange successors,
190                                MutableArrayRef<std::unique_ptr<Region>> regions)
191     : OperationState(location, OperationName(name, location.getContext()),
192                      operands, types, attributes, successors, regions) {}
193 
194 void OperationState::addOperands(ValueRange newOperands) {
195   operands.append(newOperands.begin(), newOperands.end());
196 }
197 
198 void OperationState::addSuccessors(BlockRange newSuccessors) {
199   successors.append(newSuccessors.begin(), newSuccessors.end());
200 }
201 
202 Region *OperationState::addRegion() {
203   regions.emplace_back(new Region);
204   return regions.back().get();
205 }
206 
207 void OperationState::addRegion(std::unique_ptr<Region> &&region) {
208   regions.push_back(std::move(region));
209 }
210 
211 void OperationState::addRegions(
212     MutableArrayRef<std::unique_ptr<Region>> regions) {
213   for (std::unique_ptr<Region> &region : regions)
214     addRegion(std::move(region));
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // OperandStorage
219 //===----------------------------------------------------------------------===//
220 
221 detail::OperandStorage::OperandStorage(Operation *owner,
222                                        OpOperand *trailingOperands,
223                                        ValueRange values)
224     : isStorageDynamic(false), operandStorage(trailingOperands) {
225   numOperands = capacity = values.size();
226   for (unsigned i = 0; i < numOperands; ++i)
227     new (&operandStorage[i]) OpOperand(owner, values[i]);
228 }
229 
230 detail::OperandStorage::~OperandStorage() {
231   for (auto &operand : getOperands())
232     operand.~OpOperand();
233 
234   // If the storage is dynamic, deallocate it.
235   if (isStorageDynamic)
236     free(operandStorage);
237 }
238 
239 /// Replace the operands contained in the storage with the ones provided in
240 /// 'values'.
241 void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
242   MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
243   for (unsigned i = 0, e = values.size(); i != e; ++i)
244     storageOperands[i].set(values[i]);
245 }
246 
247 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
248 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
249 /// than the range pointed to by 'start'+'length'.
250 void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
251                                          unsigned length, ValueRange operands) {
252   // If the new size is the same, we can update inplace.
253   unsigned newSize = operands.size();
254   if (newSize == length) {
255     MutableArrayRef<OpOperand> storageOperands = getOperands();
256     for (unsigned i = 0, e = length; i != e; ++i)
257       storageOperands[start + i].set(operands[i]);
258     return;
259   }
260   // If the new size is greater, remove the extra operands and set the rest
261   // inplace.
262   if (newSize < length) {
263     eraseOperands(start + operands.size(), length - newSize);
264     setOperands(owner, start, newSize, operands);
265     return;
266   }
267   // Otherwise, the new size is greater so we need to grow the storage.
268   auto storageOperands = resize(owner, size() + (newSize - length));
269 
270   // Shift operands to the right to make space for the new operands.
271   unsigned rotateSize = storageOperands.size() - (start + length);
272   auto rbegin = storageOperands.rbegin();
273   std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
274 
275   // Update the operands inplace.
276   for (unsigned i = 0, e = operands.size(); i != e; ++i)
277     storageOperands[start + i].set(operands[i]);
278 }
279 
280 /// Erase an operand held by the storage.
281 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
282   MutableArrayRef<OpOperand> operands = getOperands();
283   assert((start + length) <= operands.size());
284   numOperands -= length;
285 
286   // Shift all operands down if the operand to remove is not at the end.
287   if (start != numOperands) {
288     auto *indexIt = std::next(operands.begin(), start);
289     std::rotate(indexIt, std::next(indexIt, length), operands.end());
290   }
291   for (unsigned i = 0; i != length; ++i)
292     operands[numOperands + i].~OpOperand();
293 }
294 
295 void detail::OperandStorage::eraseOperands(
296     const BitVector &eraseIndices) {
297   MutableArrayRef<OpOperand> operands = getOperands();
298   assert(eraseIndices.size() == operands.size());
299 
300   // Check that at least one operand is erased.
301   int firstErasedIndice = eraseIndices.find_first();
302   if (firstErasedIndice == -1)
303     return;
304 
305   // Shift all of the removed operands to the end, and destroy them.
306   numOperands = firstErasedIndice;
307   for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
308     if (!eraseIndices.test(i))
309       operands[numOperands++] = std::move(operands[i]);
310   for (OpOperand &operand : operands.drop_front(numOperands))
311     operand.~OpOperand();
312 }
313 
314 /// Resize the storage to the given size. Returns the array containing the new
315 /// operands.
316 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
317                                                           unsigned newSize) {
318   // If the number of operands is less than or equal to the current amount, we
319   // can just update in place.
320   MutableArrayRef<OpOperand> origOperands = getOperands();
321   if (newSize <= numOperands) {
322     // If the number of new size is less than the current, remove any extra
323     // operands.
324     for (unsigned i = newSize; i != numOperands; ++i)
325       origOperands[i].~OpOperand();
326     numOperands = newSize;
327     return origOperands.take_front(newSize);
328   }
329 
330   // If the new size is within the original inline capacity, grow inplace.
331   if (newSize <= capacity) {
332     OpOperand *opBegin = origOperands.data();
333     for (unsigned e = newSize; numOperands != e; ++numOperands)
334       new (&opBegin[numOperands]) OpOperand(owner);
335     return MutableArrayRef<OpOperand>(opBegin, newSize);
336   }
337 
338   // Otherwise, we need to allocate a new storage.
339   unsigned newCapacity =
340       std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
341   OpOperand *newOperandStorage =
342       reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));
343 
344   // Move the current operands to the new storage.
345   MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
346   std::uninitialized_copy(std::make_move_iterator(origOperands.begin()),
347                           std::make_move_iterator(origOperands.end()),
348                           newOperands.begin());
349 
350   // Destroy the original operands.
351   for (auto &operand : origOperands)
352     operand.~OpOperand();
353 
354   // Initialize any new operands.
355   for (unsigned e = newSize; numOperands != e; ++numOperands)
356     new (&newOperands[numOperands]) OpOperand(owner);
357 
358   // If the current storage is dynamic, free it.
359   if (isStorageDynamic)
360     free(operandStorage);
361 
362   // Update the storage representation to use the new dynamic storage.
363   operandStorage = newOperandStorage;
364   capacity = newCapacity;
365   isStorageDynamic = true;
366   return newOperands;
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // Operation Value-Iterators
371 //===----------------------------------------------------------------------===//
372 
373 //===----------------------------------------------------------------------===//
374 // OperandRange
375 
376 unsigned OperandRange::getBeginOperandIndex() const {
377   assert(!empty() && "range must not be empty");
378   return base->getOperandNumber();
379 }
380 
381 OperandRangeRange OperandRange::split(ElementsAttr segmentSizes) const {
382   return OperandRangeRange(*this, segmentSizes);
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // OperandRangeRange
387 
388 OperandRangeRange::OperandRangeRange(OperandRange operands,
389                                      Attribute operandSegments)
390     : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
391                         operandSegments.cast<DenseElementsAttr>().size()) {}
392 
393 OperandRange OperandRangeRange::join() const {
394   const OwnerT &owner = getBase();
395   auto sizeData = owner.second.cast<DenseElementsAttr>().getValues<uint32_t>();
396   return OperandRange(owner.first,
397                       std::accumulate(sizeData.begin(), sizeData.end(), 0));
398 }
399 
400 OperandRange OperandRangeRange::dereference(const OwnerT &object,
401                                             ptrdiff_t index) {
402   auto sizeData = object.second.cast<DenseElementsAttr>().getValues<uint32_t>();
403   uint32_t startIndex =
404       std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
405   return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // MutableOperandRange
410 
411 /// Construct a new mutable range from the given operand, operand start index,
412 /// and range length.
413 MutableOperandRange::MutableOperandRange(
414     Operation *owner, unsigned start, unsigned length,
415     ArrayRef<OperandSegment> operandSegments)
416     : owner(owner), start(start), length(length),
417       operandSegments(operandSegments.begin(), operandSegments.end()) {
418   assert((start + length) <= owner->getNumOperands() && "invalid range");
419 }
420 MutableOperandRange::MutableOperandRange(Operation *owner)
421     : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
422 
423 /// Slice this range into a sub range, with the additional operand segment.
424 MutableOperandRange
425 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
426                            Optional<OperandSegment> segment) const {
427   assert((subStart + subLen) <= length && "invalid sub-range");
428   MutableOperandRange subSlice(owner, start + subStart, subLen,
429                                operandSegments);
430   if (segment)
431     subSlice.operandSegments.push_back(*segment);
432   return subSlice;
433 }
434 
435 /// Append the given values to the range.
436 void MutableOperandRange::append(ValueRange values) {
437   if (values.empty())
438     return;
439   owner->insertOperands(start + length, values);
440   updateLength(length + values.size());
441 }
442 
443 /// Assign this range to the given values.
444 void MutableOperandRange::assign(ValueRange values) {
445   owner->setOperands(start, length, values);
446   if (length != values.size())
447     updateLength(/*newLength=*/values.size());
448 }
449 
450 /// Assign the range to the given value.
451 void MutableOperandRange::assign(Value value) {
452   if (length == 1) {
453     owner->setOperand(start, value);
454   } else {
455     owner->setOperands(start, length, value);
456     updateLength(/*newLength=*/1);
457   }
458 }
459 
460 /// Erase the operands within the given sub-range.
461 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
462   assert((subStart + subLen) <= length && "invalid sub-range");
463   if (length == 0)
464     return;
465   owner->eraseOperands(start + subStart, subLen);
466   updateLength(length - subLen);
467 }
468 
469 /// Clear this range and erase all of the operands.
470 void MutableOperandRange::clear() {
471   if (length != 0) {
472     owner->eraseOperands(start, length);
473     updateLength(/*newLength=*/0);
474   }
475 }
476 
477 /// Allow implicit conversion to an OperandRange.
478 MutableOperandRange::operator OperandRange() const {
479   return owner->getOperands().slice(start, length);
480 }
481 
482 MutableOperandRangeRange
483 MutableOperandRange::split(NamedAttribute segmentSizes) const {
484   return MutableOperandRangeRange(*this, segmentSizes);
485 }
486 
487 /// Update the length of this range to the one provided.
488 void MutableOperandRange::updateLength(unsigned newLength) {
489   int32_t diff = int32_t(newLength) - int32_t(length);
490   length = newLength;
491 
492   // Update any of the provided segment attributes.
493   for (OperandSegment &segment : operandSegments) {
494     auto attr = segment.second.getValue().cast<DenseIntElementsAttr>();
495     SmallVector<int32_t, 8> segments(attr.getValues<int32_t>());
496     segments[segment.first] += diff;
497     segment.second.setValue(
498         DenseIntElementsAttr::get(attr.getType(), segments));
499     owner->setAttr(segment.second.getName(), segment.second.getValue());
500   }
501 }
502 
503 //===----------------------------------------------------------------------===//
504 // MutableOperandRangeRange
505 
506 MutableOperandRangeRange::MutableOperandRangeRange(
507     const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
508     : MutableOperandRangeRange(
509           OwnerT(operands, operandSegmentAttr), 0,
510           operandSegmentAttr.getValue().cast<DenseElementsAttr>().size()) {}
511 
512 MutableOperandRange MutableOperandRangeRange::join() const {
513   return getBase().first;
514 }
515 
516 MutableOperandRangeRange::operator OperandRangeRange() const {
517   return OperandRangeRange(
518       getBase().first, getBase().second.getValue().cast<DenseElementsAttr>());
519 }
520 
521 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
522                                                           ptrdiff_t index) {
523   auto sizeData =
524       object.second.getValue().cast<DenseElementsAttr>().getValues<uint32_t>();
525   uint32_t startIndex =
526       std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
527   return object.first.slice(
528       startIndex, *(sizeData.begin() + index),
529       MutableOperandRange::OperandSegment(index, object.second));
530 }
531 
532 //===----------------------------------------------------------------------===//
533 // ResultRange
534 
535 ResultRange::ResultRange(OpResult result)
536     : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
537                   1) {}
538 
539 ResultRange::use_range ResultRange::getUses() const {
540   return {use_begin(), use_end()};
541 }
542 ResultRange::use_iterator ResultRange::use_begin() const {
543   return use_iterator(*this);
544 }
545 ResultRange::use_iterator ResultRange::use_end() const {
546   return use_iterator(*this, /*end=*/true);
547 }
548 ResultRange::user_range ResultRange::getUsers() {
549   return {user_begin(), user_end()};
550 }
551 ResultRange::user_iterator ResultRange::user_begin() {
552   return user_iterator(use_begin());
553 }
554 ResultRange::user_iterator ResultRange::user_end() {
555   return user_iterator(use_end());
556 }
557 
558 ResultRange::UseIterator::UseIterator(ResultRange results, bool end)
559     : it(end ? results.end() : results.begin()), endIt(results.end()) {
560   // Only initialize current use if there are results/can be uses.
561   if (it != endIt)
562     skipOverResultsWithNoUsers();
563 }
564 
565 ResultRange::UseIterator &ResultRange::UseIterator::operator++() {
566   // We increment over uses, if we reach the last use then move to next
567   // result.
568   if (use != (*it).use_end())
569     ++use;
570   if (use == (*it).use_end()) {
571     ++it;
572     skipOverResultsWithNoUsers();
573   }
574   return *this;
575 }
576 
577 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
578   while (it != endIt && (*it).use_empty())
579     ++it;
580 
581   // If we are at the last result, then set use to first use of
582   // first result (sentinel value used for end).
583   if (it == endIt)
584     use = {};
585   else
586     use = (*it).use_begin();
587 }
588 
589 void ResultRange::replaceAllUsesWith(Operation *op) {
590   replaceAllUsesWith(op->getResults());
591 }
592 
593 //===----------------------------------------------------------------------===//
594 // ValueRange
595 
596 ValueRange::ValueRange(ArrayRef<Value> values)
597     : ValueRange(values.data(), values.size()) {}
598 ValueRange::ValueRange(OperandRange values)
599     : ValueRange(values.begin().getBase(), values.size()) {}
600 ValueRange::ValueRange(ResultRange values)
601     : ValueRange(values.getBase(), values.size()) {}
602 
603 /// See `llvm::detail::indexed_accessor_range_base` for details.
604 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
605                                            ptrdiff_t index) {
606   if (const auto *value = owner.dyn_cast<const Value *>())
607     return {value + index};
608   if (auto *operand = owner.dyn_cast<OpOperand *>())
609     return {operand + index};
610   return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
611 }
612 /// See `llvm::detail::indexed_accessor_range_base` for details.
613 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
614   if (const auto *value = owner.dyn_cast<const Value *>())
615     return value[index];
616   if (auto *operand = owner.dyn_cast<OpOperand *>())
617     return operand[index].get();
618   return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
619 }
620 
621 //===----------------------------------------------------------------------===//
622 // Operation Equivalency
623 //===----------------------------------------------------------------------===//
624 
625 llvm::hash_code OperationEquivalence::computeHash(
626     Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
627     function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
628   // Hash operations based upon their:
629   //   - Operation Name
630   //   - Attributes
631   //   - Result Types
632   llvm::hash_code hash = llvm::hash_combine(
633       op->getName(), op->getAttrDictionary(), op->getResultTypes());
634 
635   //   - Operands
636   ValueRange operands = op->getOperands();
637   SmallVector<Value> operandStorage;
638   if (op->hasTrait<mlir::OpTrait::IsCommutative>()) {
639     operandStorage.append(operands.begin(), operands.end());
640     llvm::sort(operandStorage, [](Value a, Value b) -> bool {
641       return a.getAsOpaquePointer() < b.getAsOpaquePointer();
642     });
643     operands = operandStorage;
644   }
645   for (Value operand : operands)
646     hash = llvm::hash_combine(hash, hashOperands(operand));
647 
648   //   - Operands
649   for (Value result : op->getResults())
650     hash = llvm::hash_combine(hash, hashResults(result));
651   return hash;
652 }
653 
654 static bool
655 isRegionEquivalentTo(Region *lhs, Region *rhs,
656                      function_ref<LogicalResult(Value, Value)> mapOperands,
657                      function_ref<LogicalResult(Value, Value)> mapResults,
658                      OperationEquivalence::Flags flags) {
659   DenseMap<Block *, Block *> blocksMap;
660   auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
661     // Check block arguments.
662     if (lBlock.getNumArguments() != rBlock.getNumArguments())
663       return false;
664 
665     // Map the two blocks.
666     auto insertion = blocksMap.insert({&lBlock, &rBlock});
667     if (insertion.first->getSecond() != &rBlock)
668       return false;
669 
670     for (auto argPair :
671          llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
672       Value curArg = std::get<0>(argPair);
673       Value otherArg = std::get<1>(argPair);
674       if (curArg.getType() != otherArg.getType())
675         return false;
676       if (!(flags & OperationEquivalence::IgnoreLocations) &&
677           curArg.getLoc() != otherArg.getLoc())
678         return false;
679       // Check if this value was already mapped to another value.
680       if (failed(mapOperands(curArg, otherArg)))
681         return false;
682     }
683 
684     auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
685       // Check for op equality (recursively).
686       if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, mapOperands,
687                                                 mapResults, flags))
688         return false;
689       // Check successor mapping.
690       for (auto successorsPair :
691            llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
692         Block *curSuccessor = std::get<0>(successorsPair);
693         Block *otherSuccessor = std::get<1>(successorsPair);
694         auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
695         if (insertion.first->getSecond() != otherSuccessor)
696           return false;
697       }
698       return true;
699     };
700     return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
701   };
702   return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
703 }
704 
705 bool OperationEquivalence::isEquivalentTo(
706     Operation *lhs, Operation *rhs,
707     function_ref<LogicalResult(Value, Value)> mapOperands,
708     function_ref<LogicalResult(Value, Value)> mapResults, Flags flags) {
709   if (lhs == rhs)
710     return true;
711 
712   // Compare the operation properties.
713   if (lhs->getName() != rhs->getName() ||
714       lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
715       lhs->getNumRegions() != rhs->getNumRegions() ||
716       lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
717       lhs->getNumOperands() != rhs->getNumOperands() ||
718       lhs->getNumResults() != rhs->getNumResults())
719     return false;
720   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
721     return false;
722 
723   ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
724   SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
725   if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
726     lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end());
727     llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool {
728       return a.getAsOpaquePointer() < b.getAsOpaquePointer();
729     });
730     lhsOperands = lhsOperandStorage;
731 
732     rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end());
733     llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool {
734       return a.getAsOpaquePointer() < b.getAsOpaquePointer();
735     });
736     rhsOperands = rhsOperandStorage;
737   }
738   auto checkValueRangeMapping =
739       [](ValueRange lhs, ValueRange rhs,
740          function_ref<LogicalResult(Value, Value)> mapValues) {
741         for (auto operandPair : llvm::zip(lhs, rhs)) {
742           Value curArg = std::get<0>(operandPair);
743           Value otherArg = std::get<1>(operandPair);
744           if (curArg.getType() != otherArg.getType())
745             return false;
746           if (failed(mapValues(curArg, otherArg)))
747             return false;
748         }
749         return true;
750       };
751   // Check mapping of operands and results.
752   if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands))
753     return false;
754   if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
755     return false;
756   for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
757     if (!isRegionEquivalentTo(&std::get<0>(regionPair),
758                               &std::get<1>(regionPair), mapOperands, mapResults,
759                               flags))
760       return false;
761   return true;
762 }
763