1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/OperationSupport.h"
15 #include "mlir/IR/BuiltinAttributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/BitVector.h"
19 #include <numeric>
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // NamedAttrList
25 //===----------------------------------------------------------------------===//
26 
27 NamedAttrList::NamedAttrList(ArrayRef<NamedAttribute> attributes) {
28   assign(attributes.begin(), attributes.end());
29 }
30 
31 NamedAttrList::NamedAttrList(DictionaryAttr attributes)
32     : NamedAttrList(attributes ? attributes.getValue()
33                                : ArrayRef<NamedAttribute>()) {
34   dictionarySorted.setPointerAndInt(attributes, true);
35 }
36 
37 NamedAttrList::NamedAttrList(const_iterator in_start, const_iterator in_end) {
38   assign(in_start, in_end);
39 }
40 
41 ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; }
42 
43 Optional<NamedAttribute> NamedAttrList::findDuplicate() const {
44   Optional<NamedAttribute> duplicate =
45       DictionaryAttr::findDuplicate(attrs, isSorted());
46   // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
47   // state.
48   if (!isSorted())
49     dictionarySorted.setPointerAndInt(nullptr, true);
50   return duplicate;
51 }
52 
53 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
54   if (!isSorted()) {
55     DictionaryAttr::sortInPlace(attrs);
56     dictionarySorted.setPointerAndInt(nullptr, true);
57   }
58   if (!dictionarySorted.getPointer())
59     dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
60   return dictionarySorted.getPointer().cast<DictionaryAttr>();
61 }
62 
63 /// Add an attribute with the specified name.
64 void NamedAttrList::append(StringRef name, Attribute attr) {
65   append(Identifier::get(name, attr.getContext()), attr);
66 }
67 
68 /// Replaces the attributes with new list of attributes.
69 void NamedAttrList::assign(const_iterator in_start, const_iterator in_end) {
70   DictionaryAttr::sort(ArrayRef<NamedAttribute>{in_start, in_end}, attrs);
71   dictionarySorted.setPointerAndInt(nullptr, true);
72 }
73 
74 void NamedAttrList::push_back(NamedAttribute newAttribute) {
75   assert(newAttribute.second && "unexpected null attribute");
76   if (isSorted())
77     dictionarySorted.setInt(
78         attrs.empty() ||
79         strcmp(attrs.back().first.data(), newAttribute.first.data()) < 0);
80   dictionarySorted.setPointer(nullptr);
81   attrs.push_back(newAttribute);
82 }
83 
84 /// Helper function to find attribute in possible sorted vector of
85 /// NamedAttributes.
86 template <typename T>
87 static auto *findAttr(SmallVectorImpl<NamedAttribute> &attrs, T name,
88                       bool sorted) {
89   if (!sorted) {
90     return llvm::find_if(
91         attrs, [name](NamedAttribute attr) { return attr.first == name; });
92   }
93 
94   auto *it = llvm::lower_bound(attrs, name);
95   if (it == attrs.end() || it->first != name)
96     return attrs.end();
97   return it;
98 }
99 
100 /// Return the specified attribute if present, null otherwise.
101 Attribute NamedAttrList::get(StringRef name) const {
102   auto *it = findAttr(attrs, name, isSorted());
103   return it != attrs.end() ? it->second : nullptr;
104 }
105 
106 /// Return the specified attribute if present, null otherwise.
107 Attribute NamedAttrList::get(Identifier name) const {
108   auto *it = findAttr(attrs, name, isSorted());
109   return it != attrs.end() ? it->second : nullptr;
110 }
111 
112 /// Return the specified named attribute if present, None otherwise.
113 Optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
114   auto *it = findAttr(attrs, name, isSorted());
115   return it != attrs.end() ? *it : Optional<NamedAttribute>();
116 }
117 Optional<NamedAttribute> NamedAttrList::getNamed(Identifier name) const {
118   auto *it = findAttr(attrs, name, isSorted());
119   return it != attrs.end() ? *it : Optional<NamedAttribute>();
120 }
121 
122 /// If the an attribute exists with the specified name, change it to the new
123 /// value.  Otherwise, add a new attribute with the specified name/value.
124 Attribute NamedAttrList::set(Identifier name, Attribute value) {
125   assert(value && "attributes may never be null");
126 
127   // Look for an existing value for the given name, and set it in-place.
128   auto *it = findAttr(attrs, name, isSorted());
129   if (it != attrs.end()) {
130     // Only update if the value is different from the existing.
131     Attribute oldValue = it->second;
132     if (oldValue != value) {
133       dictionarySorted.setPointer(nullptr);
134       it->second = value;
135     }
136     return oldValue;
137   }
138 
139   // Otherwise, insert the new attribute into its sorted position.
140   it = llvm::lower_bound(attrs, name);
141   dictionarySorted.setPointer(nullptr);
142   attrs.insert(it, {name, value});
143   return Attribute();
144 }
145 Attribute NamedAttrList::set(StringRef name, Attribute value) {
146   assert(value && "setting null attribute not supported");
147   return set(mlir::Identifier::get(name, value.getContext()), value);
148 }
149 
150 Attribute
151 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
152   if (it == attrs.end())
153     return nullptr;
154 
155   // Erasing does not affect the sorted property.
156   Attribute attr = it->second;
157   attrs.erase(it);
158   dictionarySorted.setPointer(nullptr);
159   return attr;
160 }
161 
162 Attribute NamedAttrList::erase(Identifier name) {
163   return eraseImpl(findAttr(attrs, name, isSorted()));
164 }
165 
166 Attribute NamedAttrList::erase(StringRef name) {
167   return eraseImpl(findAttr(attrs, name, isSorted()));
168 }
169 
170 NamedAttrList &
171 NamedAttrList::operator=(const SmallVectorImpl<NamedAttribute> &rhs) {
172   assign(rhs.begin(), rhs.end());
173   return *this;
174 }
175 
176 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
177 
178 //===----------------------------------------------------------------------===//
179 // OperationState
180 //===----------------------------------------------------------------------===//
181 
182 OperationState::OperationState(Location location, StringRef name)
183     : location(location), name(name, location->getContext()) {}
184 
185 OperationState::OperationState(Location location, OperationName name)
186     : location(location), name(name) {}
187 
188 OperationState::OperationState(Location location, StringRef name,
189                                ValueRange operands, TypeRange types,
190                                ArrayRef<NamedAttribute> attributes,
191                                BlockRange successors,
192                                MutableArrayRef<std::unique_ptr<Region>> regions)
193     : location(location), name(name, location->getContext()),
194       operands(operands.begin(), operands.end()),
195       types(types.begin(), types.end()),
196       attributes(attributes.begin(), attributes.end()),
197       successors(successors.begin(), successors.end()) {
198   for (std::unique_ptr<Region> &r : regions)
199     this->regions.push_back(std::move(r));
200 }
201 
202 void OperationState::addOperands(ValueRange newOperands) {
203   operands.append(newOperands.begin(), newOperands.end());
204 }
205 
206 void OperationState::addSuccessors(BlockRange newSuccessors) {
207   successors.append(newSuccessors.begin(), newSuccessors.end());
208 }
209 
210 Region *OperationState::addRegion() {
211   regions.emplace_back(new Region);
212   return regions.back().get();
213 }
214 
215 void OperationState::addRegion(std::unique_ptr<Region> &&region) {
216   regions.push_back(std::move(region));
217 }
218 
219 void OperationState::addRegions(
220     MutableArrayRef<std::unique_ptr<Region>> regions) {
221   for (std::unique_ptr<Region> &region : regions)
222     addRegion(std::move(region));
223 }
224 
225 //===----------------------------------------------------------------------===//
226 // OperandStorage
227 //===----------------------------------------------------------------------===//
228 
229 detail::OperandStorage::OperandStorage(Operation *owner, ValueRange values)
230     : inlineStorage() {
231   auto &inlineStorage = getInlineStorage();
232   inlineStorage.numOperands = inlineStorage.capacity = values.size();
233   auto *operandPtrBegin = getTrailingObjects<OpOperand>();
234   for (unsigned i = 0, e = inlineStorage.numOperands; i < e; ++i)
235     new (&operandPtrBegin[i]) OpOperand(owner, values[i]);
236 }
237 
/// Destroy the operand storage. Whichever storage container is currently
/// active is destructed. If it is the dynamic (malloc'd by `resize`)
/// container, its memory is also released with `free`.
detail::OperandStorage::~OperandStorage() {
  // Destruct the current storage container.
  if (isDynamicStorage()) {
    TrailingOperandStorage &storage = getDynamicStorage();
    storage.~TrailingOperandStorage();
    // Work around -Wfree-nonheap-object false positive fixed by D102728.
    auto *mem = &storage;
    free(mem);
  } else {
    // Inline storage lives inside this object; only run its destructor.
    getInlineStorage().~TrailingOperandStorage();
  }
}
250 
251 /// Replace the operands contained in the storage with the ones provided in
252 /// 'values'.
253 void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
254   MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
255   for (unsigned i = 0, e = values.size(); i != e; ++i)
256     storageOperands[i].set(values[i]);
257 }
258 
/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
/// than the range pointed to by 'start'+'length'. Operands after the replaced
/// range keep their relative order and are shifted as needed.
void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
                                         unsigned length, ValueRange operands) {
  // If the new size is the same, we can update inplace.
  unsigned newSize = operands.size();
  if (newSize == length) {
    MutableArrayRef<OpOperand> storageOperands = getOperands();
    for (unsigned i = 0, e = length; i != e; ++i)
      storageOperands[start + i].set(operands[i]);
    return;
  }
  // If the new size is smaller, remove the extra operands and set the rest
  // inplace. The recursive call takes the equal-size path above.
  if (newSize < length) {
    eraseOperands(start + operands.size(), length - newSize);
    setOperands(owner, start, newSize, operands);
    return;
  }
  // Otherwise, the new size is greater so we need to grow the storage.
  auto storageOperands = resize(owner, size() + (newSize - length));

  // Shift operands to the right to make space for the new operands. The
  // (newSize - length) freshly created operands sit at the very end. Rotating
  // the trailing `rotateSize` elements (the old tail after the replaced range,
  // plus the fresh operands) via reverse iterators moves the fresh operands to
  // start at 'start' + 'length' and pushes the old tail to the end.
  unsigned rotateSize = storageOperands.size() - (start + length);
  auto rbegin = storageOperands.rbegin();
  std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);

  // Update the operands inplace.
  for (unsigned i = 0, e = operands.size(); i != e; ++i)
    storageOperands[start + i].set(operands[i]);
}
291 
/// Erase `length` operands starting at index `start`. Later operands are
/// shifted down to close the gap, keeping their relative order.
void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
  TrailingOperandStorage &storage = getStorage();
  // Captured before the count is reduced, so this still spans every slot that
  // currently holds a constructed operand.
  MutableArrayRef<OpOperand> operands = storage.getOperands();
  assert((start + length) <= operands.size());
  storage.numOperands -= length;

  // Shift all operands down if the operand to remove is not at the end.
  if (start != storage.numOperands) {
    auto *indexIt = std::next(operands.begin(), start);
    std::rotate(indexIt, std::next(indexIt, length), operands.end());
  }
  // The erased operands now occupy the last `length` slots; destroy them.
  for (unsigned i = 0; i != length; ++i)
    operands[storage.numOperands + i].~OpOperand();
}
307 
/// Erase every operand whose bit is set in `eraseIndices`, which must have
/// exactly one bit per current operand. Surviving operands keep their
/// relative order and are compacted to the front of the storage.
void detail::OperandStorage::eraseOperands(
    const llvm::BitVector &eraseIndices) {
  TrailingOperandStorage &storage = getStorage();
  MutableArrayRef<OpOperand> operands = storage.getOperands();
  assert(eraseIndices.size() == operands.size());

  // Check that at least one operand is erased; `find_first` yields -1 when no
  // bit is set, in which case there is nothing to do.
  int firstErasedIndice = eraseIndices.find_first();
  if (firstErasedIndice == -1)
    return;

  // Operands before the first erased index are already in place. Move each
  // later surviving operand down into the next free slot, then destroy the
  // leftover (moved-from or erased) operands in the tail.
  storage.numOperands = firstErasedIndice;
  for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
    if (!eraseIndices.test(i))
      operands[storage.numOperands++] = std::move(operands[i]);
  for (OpOperand &operand : operands.drop_front(storage.numOperands))
    operand.~OpOperand();
}
327 
/// Resize the storage to the given size. Returns the array containing the new
/// operands.
///
/// - Shrinking destroys the trailing operands in place.
/// - Growing within the current capacity constructs the new operands in place.
/// - Otherwise a larger buffer is malloc'd, the existing operands are moved
///   into it, and any previous dynamic buffer is freed.
///
/// Newly added operands are constructed with only their owner set.
MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
                                                          unsigned newSize) {
  TrailingOperandStorage &storage = getStorage();

  // If the number of operands is less than or equal to the current amount, we
  // can just update in place.
  unsigned &numOperands = storage.numOperands;
  MutableArrayRef<OpOperand> operands = storage.getOperands();
  if (newSize <= numOperands) {
    // If the number of new size is less than the current, remove any extra
    // operands.
    for (unsigned i = newSize; i != numOperands; ++i)
      operands[i].~OpOperand();
    numOperands = newSize;
    return operands.take_front(newSize);
  }

  // If the new size is within the current capacity (inline or dynamic), grow
  // inplace.
  if (newSize <= storage.capacity) {
    OpOperand *opBegin = operands.data();
    for (unsigned e = newSize; numOperands != e; ++numOperands)
      new (&opBegin[numOperands]) OpOperand(owner);
    return MutableArrayRef<OpOperand>(opBegin, newSize);
  }

  // Otherwise, we need to allocate a new storage. The capacity grows to the
  // next power of two above `capacity + 2`, but never below `newSize`.
  // NOTE(review): the malloc result is not checked for null.
  unsigned newCapacity =
      std::max(unsigned(llvm::NextPowerOf2(storage.capacity + 2)), newSize);
  auto *newStorageMem =
      malloc(TrailingOperandStorage::totalSizeToAlloc<OpOperand>(newCapacity));
  auto *newStorage = ::new (newStorageMem) TrailingOperandStorage();
  newStorage->numOperands = newSize;
  newStorage->capacity = newCapacity;

  // Move the current operands to the new storage.
  MutableArrayRef<OpOperand> newOperands = newStorage->getOperands();
  std::uninitialized_copy(std::make_move_iterator(operands.begin()),
                          std::make_move_iterator(operands.end()),
                          newOperands.begin());

  // Destroy the original operands.
  for (auto &operand : operands)
    operand.~OpOperand();

  // Initialize any new operands. `numOperands` still aliases the old storage's
  // count, so it starts at the index of the first unconstructed slot in
  // `newOperands`.
  for (unsigned e = newSize; numOperands != e; ++numOperands)
    new (&newOperands[numOperands]) OpOperand(owner);

  // If the current storage is also dynamic, free it.
  if (isDynamicStorage()) {
    // Work around -Wfree-nonheap-object false positive fixed by D102728.
    auto *mem = &storage;
    free(mem);
  }

  // Update the storage representation to use the new dynamic storage.
  dynamicStorage.setPointerAndInt(newStorage, true);
  return newOperands;
}
389 
390 //===----------------------------------------------------------------------===//
391 // Operation Value-Iterators
392 //===----------------------------------------------------------------------===//
393 
394 //===----------------------------------------------------------------------===//
395 // OperandRange
396 
397 OperandRange::OperandRange(Operation *op)
398     : OperandRange(op->getOpOperands().data(), op->getNumOperands()) {}
399 
400 unsigned OperandRange::getBeginOperandIndex() const {
401   assert(!empty() && "range must not be empty");
402   return base->getOperandNumber();
403 }
404 
405 OperandRangeRange OperandRange::split(ElementsAttr segmentSizes) const {
406   return OperandRangeRange(*this, segmentSizes);
407 }
408 
409 //===----------------------------------------------------------------------===//
410 // OperandRangeRange
411 
412 OperandRangeRange::OperandRangeRange(OperandRange operands,
413                                      Attribute operandSegments)
414     : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
415                         operandSegments.cast<DenseElementsAttr>().size()) {}
416 
417 OperandRange OperandRangeRange::join() const {
418   const OwnerT &owner = getBase();
419   auto sizeData = owner.second.cast<DenseElementsAttr>().getValues<uint32_t>();
420   return OperandRange(owner.first,
421                       std::accumulate(sizeData.begin(), sizeData.end(), 0));
422 }
423 
424 OperandRange OperandRangeRange::dereference(const OwnerT &object,
425                                             ptrdiff_t index) {
426   auto sizeData = object.second.cast<DenseElementsAttr>().getValues<uint32_t>();
427   uint32_t startIndex =
428       std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
429   return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // MutableOperandRange
434 
435 /// Construct a new mutable range from the given operand, operand start index,
436 /// and range length.
437 MutableOperandRange::MutableOperandRange(
438     Operation *owner, unsigned start, unsigned length,
439     ArrayRef<OperandSegment> operandSegments)
440     : owner(owner), start(start), length(length),
441       operandSegments(operandSegments.begin(), operandSegments.end()) {
442   assert((start + length) <= owner->getNumOperands() && "invalid range");
443 }
444 MutableOperandRange::MutableOperandRange(Operation *owner)
445     : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
446 
447 /// Slice this range into a sub range, with the additional operand segment.
448 MutableOperandRange
449 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
450                            Optional<OperandSegment> segment) const {
451   assert((subStart + subLen) <= length && "invalid sub-range");
452   MutableOperandRange subSlice(owner, start + subStart, subLen,
453                                operandSegments);
454   if (segment)
455     subSlice.operandSegments.push_back(*segment);
456   return subSlice;
457 }
458 
459 /// Append the given values to the range.
460 void MutableOperandRange::append(ValueRange values) {
461   if (values.empty())
462     return;
463   owner->insertOperands(start + length, values);
464   updateLength(length + values.size());
465 }
466 
467 /// Assign this range to the given values.
468 void MutableOperandRange::assign(ValueRange values) {
469   owner->setOperands(start, length, values);
470   if (length != values.size())
471     updateLength(/*newLength=*/values.size());
472 }
473 
474 /// Assign the range to the given value.
475 void MutableOperandRange::assign(Value value) {
476   if (length == 1) {
477     owner->setOperand(start, value);
478   } else {
479     owner->setOperands(start, length, value);
480     updateLength(/*newLength=*/1);
481   }
482 }
483 
484 /// Erase the operands within the given sub-range.
485 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
486   assert((subStart + subLen) <= length && "invalid sub-range");
487   if (length == 0)
488     return;
489   owner->eraseOperands(start + subStart, subLen);
490   updateLength(length - subLen);
491 }
492 
493 /// Clear this range and erase all of the operands.
494 void MutableOperandRange::clear() {
495   if (length != 0) {
496     owner->eraseOperands(start, length);
497     updateLength(/*newLength=*/0);
498   }
499 }
500 
501 /// Allow implicit conversion to an OperandRange.
502 MutableOperandRange::operator OperandRange() const {
503   return owner->getOperands().slice(start, length);
504 }
505 
506 MutableOperandRangeRange
507 MutableOperandRange::split(NamedAttribute segmentSizes) const {
508   return MutableOperandRangeRange(*this, segmentSizes);
509 }
510 
511 /// Update the length of this range to the one provided.
512 void MutableOperandRange::updateLength(unsigned newLength) {
513   int32_t diff = int32_t(newLength) - int32_t(length);
514   length = newLength;
515 
516   // Update any of the provided segment attributes.
517   for (OperandSegment &segment : operandSegments) {
518     auto attr = segment.second.second.cast<DenseIntElementsAttr>();
519     SmallVector<int32_t, 8> segments(attr.getValues<int32_t>());
520     segments[segment.first] += diff;
521     segment.second.second = DenseIntElementsAttr::get(attr.getType(), segments);
522     owner->setAttr(segment.second.first, segment.second.second);
523   }
524 }
525 
526 //===----------------------------------------------------------------------===//
527 // MutableOperandRangeRange
528 
529 MutableOperandRangeRange::MutableOperandRangeRange(
530     const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
531     : MutableOperandRangeRange(
532           OwnerT(operands, operandSegmentAttr), 0,
533           operandSegmentAttr.second.cast<DenseElementsAttr>().size()) {}
534 
535 MutableOperandRange MutableOperandRangeRange::join() const {
536   return getBase().first;
537 }
538 
539 MutableOperandRangeRange::operator OperandRangeRange() const {
540   return OperandRangeRange(getBase().first,
541                            getBase().second.second.cast<DenseElementsAttr>());
542 }
543 
544 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
545                                                           ptrdiff_t index) {
546   auto sizeData =
547       object.second.second.cast<DenseElementsAttr>().getValues<uint32_t>();
548   uint32_t startIndex =
549       std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
550   return object.first.slice(
551       startIndex, *(sizeData.begin() + index),
552       MutableOperandRange::OperandSegment(index, object.second));
553 }
554 
555 //===----------------------------------------------------------------------===//
556 // ResultRange
557 
558 ResultRange::use_range ResultRange::getUses() const {
559   return {use_begin(), use_end()};
560 }
561 ResultRange::use_iterator ResultRange::use_begin() const {
562   return use_iterator(*this);
563 }
564 ResultRange::use_iterator ResultRange::use_end() const {
565   return use_iterator(*this, /*end=*/true);
566 }
567 ResultRange::user_range ResultRange::getUsers() {
568   return {user_begin(), user_end()};
569 }
570 ResultRange::user_iterator ResultRange::user_begin() {
571   return user_iterator(use_begin());
572 }
573 ResultRange::user_iterator ResultRange::user_end() {
574   return user_iterator(use_end());
575 }
576 
577 ResultRange::UseIterator::UseIterator(ResultRange results, bool end)
578     : it(end ? results.end() : results.begin()), endIt(results.end()) {
579   // Only initialize current use if there are results/can be uses.
580   if (it != endIt)
581     skipOverResultsWithNoUsers();
582 }
583 
584 ResultRange::UseIterator &ResultRange::UseIterator::operator++() {
585   // We increment over uses, if we reach the last use then move to next
586   // result.
587   if (use != (*it).use_end())
588     ++use;
589   if (use == (*it).use_end()) {
590     ++it;
591     skipOverResultsWithNoUsers();
592   }
593   return *this;
594 }
595 
596 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
597   while (it != endIt && (*it).use_empty())
598     ++it;
599 
600   // If we are at the last result, then set use to first use of
601   // first result (sentinel value used for end).
602   if (it == endIt)
603     use = {};
604   else
605     use = (*it).use_begin();
606 }
607 
608 //===----------------------------------------------------------------------===//
609 // ValueRange
610 
611 ValueRange::ValueRange(ArrayRef<Value> values)
612     : ValueRange(values.data(), values.size()) {}
613 ValueRange::ValueRange(OperandRange values)
614     : ValueRange(values.begin().getBase(), values.size()) {}
615 ValueRange::ValueRange(ResultRange values)
616     : ValueRange(values.getBase(), values.size()) {}
617 
618 /// See `llvm::detail::indexed_accessor_range_base` for details.
619 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
620                                            ptrdiff_t index) {
621   if (const auto *value = owner.dyn_cast<const Value *>())
622     return {value + index};
623   if (auto *operand = owner.dyn_cast<OpOperand *>())
624     return {operand + index};
625   return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
626 }
627 /// See `llvm::detail::indexed_accessor_range_base` for details.
628 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
629   if (const auto *value = owner.dyn_cast<const Value *>())
630     return value[index];
631   if (auto *operand = owner.dyn_cast<OpOperand *>())
632     return operand[index].get();
633   return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
634 }
635 
636 //===----------------------------------------------------------------------===//
637 // Operation Equivalency
638 //===----------------------------------------------------------------------===//
639 
640 llvm::hash_code OperationEquivalence::computeHash(
641     Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
642     function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
643   // Hash operations based upon their:
644   //   - Operation Name
645   //   - Attributes
646   //   - Result Types
647   llvm::hash_code hash = llvm::hash_combine(
648       op->getName(), op->getAttrDictionary(), op->getResultTypes());
649 
650   //   - Operands
651   for (Value operand : op->getOperands())
652     hash = llvm::hash_combine(hash, hashOperands(operand));
653   //   - Operands
654   for (Value result : op->getResults())
655     hash = llvm::hash_combine(hash, hashResults(result));
656   return hash;
657 }
658 
/// Return true if regions `lhs` and `rhs` are structurally equivalent.
/// Corresponding blocks (walked pairwise with `llvm::all_of_zip`) must have:
///  - the same number of arguments, with matching argument types;
///  - matching argument locations, unless `IgnoreLocations` is set in `flags`;
///  - pairwise-equivalent operations.
/// Block arguments are related through `mapOperands`. Blocks, including those
/// referenced as successors, must map consistently: once an lhs block is
/// paired with an rhs block, every later pairing of that lhs block must name
/// the same rhs block.
static bool
isRegionEquivalentTo(Region *lhs, Region *rhs,
                     function_ref<LogicalResult(Value, Value)> mapOperands,
                     function_ref<LogicalResult(Value, Value)> mapResults,
                     OperationEquivalence::Flags flags) {
  DenseMap<Block *, Block *> blocksMap;
  auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
    // Check block arguments.
    if (lBlock.getNumArguments() != rBlock.getNumArguments())
      return false;

    // Map the two blocks. If lBlock was already mapped (e.g. as a successor
    // seen earlier), the existing mapping must agree with rBlock.
    auto insertion = blocksMap.insert({&lBlock, &rBlock});
    if (insertion.first->getSecond() != &rBlock)
      return false;

    for (auto argPair :
         llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
      Value curArg = std::get<0>(argPair);
      Value otherArg = std::get<1>(argPair);
      if (curArg.getType() != otherArg.getType())
        return false;
      if (!(flags & OperationEquivalence::IgnoreLocations) &&
          curArg.getLoc() != otherArg.getLoc())
        return false;
      // Check if this value was already mapped to another value.
      if (failed(mapOperands(curArg, otherArg)))
        return false;
    }

    auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
      // Check for op equality (recursively).
      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, mapOperands,
                                                mapResults, flags))
        return false;
      // Check successor mapping: each lhs successor must map to the same rhs
      // block every time it is seen.
      for (auto successorsPair :
           llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
        Block *curSuccessor = std::get<0>(successorsPair);
        Block *otherSuccessor = std::get<1>(successorsPair);
        auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
        if (insertion.first->getSecond() != otherSuccessor)
          return false;
      }
      return true;
    };
    return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
  };
  return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
}
709 
/// Return true if `lhs` and `rhs` are equivalent operations. Identical
/// pointers are trivially equivalent. Otherwise the two operations must have:
///  - the same name and attribute dictionary;
///  - equal region, successor, operand and result counts;
///  - the same location, unless `IgnoreLocations` is set in `flags`.
/// Each operand and result pair must also have the same type and be accepted
/// by `mapOperands` / `mapResults` respectively. Regions are compared
/// recursively with `isRegionEquivalentTo`. Successor blocks are only counted
/// here; their mapping is checked when the enclosing region is compared.
bool OperationEquivalence::isEquivalentTo(
    Operation *lhs, Operation *rhs,
    function_ref<LogicalResult(Value, Value)> mapOperands,
    function_ref<LogicalResult(Value, Value)> mapResults, Flags flags) {
  if (lhs == rhs)
    return true;

  // Compare the operation properties.
  if (lhs->getName() != rhs->getName() ||
      lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
      lhs->getNumRegions() != rhs->getNumRegions() ||
      lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
      lhs->getNumOperands() != rhs->getNumOperands() ||
      lhs->getNumResults() != rhs->getNumResults())
    return false;
  if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
    return false;

  // Pairwise check: same type, and the caller's mapping accepts the pair.
  auto checkValueRangeMapping =
      [](ValueRange lhs, ValueRange rhs,
         function_ref<LogicalResult(Value, Value)> mapValues) {
        for (auto operandPair : llvm::zip(lhs, rhs)) {
          Value curArg = std::get<0>(operandPair);
          Value otherArg = std::get<1>(operandPair);
          if (curArg.getType() != otherArg.getType())
            return false;
          if (failed(mapValues(curArg, otherArg)))
            return false;
        }
        return true;
      };
  // Check mapping of operands and results.
  if (!checkValueRangeMapping(lhs->getOperands(), rhs->getOperands(),
                              mapOperands))
    return false;
  if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
    return false;
  // Finally, compare the nested regions pairwise.
  for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
    if (!isRegionEquivalentTo(&std::get<0>(regionPair),
                              &std::get<1>(regionPair), mapOperands, mapResults,
                              flags))
      return false;
  return true;
}
754