1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/IR/OperationSupport.h"
15 #include "mlir/IR/BuiltinAttributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/BitVector.h"
19 #include <numeric>
20
21 using namespace mlir;
22
23 //===----------------------------------------------------------------------===//
24 // NamedAttrList
25 //===----------------------------------------------------------------------===//
26
NamedAttrList(ArrayRef<NamedAttribute> attributes)27 NamedAttrList::NamedAttrList(ArrayRef<NamedAttribute> attributes) {
28 assign(attributes.begin(), attributes.end());
29 }
30
NamedAttrList(DictionaryAttr attributes)31 NamedAttrList::NamedAttrList(DictionaryAttr attributes)
32 : NamedAttrList(attributes ? attributes.getValue()
33 : ArrayRef<NamedAttribute>()) {
34 dictionarySorted.setPointerAndInt(attributes, true);
35 }
36
NamedAttrList(const_iterator inStart,const_iterator inEnd)37 NamedAttrList::NamedAttrList(const_iterator inStart, const_iterator inEnd) {
38 assign(inStart, inEnd);
39 }
40
getAttrs() const41 ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; }
42
findDuplicate() const43 Optional<NamedAttribute> NamedAttrList::findDuplicate() const {
44 Optional<NamedAttribute> duplicate =
45 DictionaryAttr::findDuplicate(attrs, isSorted());
46 // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
47 // state.
48 if (!isSorted())
49 dictionarySorted.setPointerAndInt(nullptr, true);
50 return duplicate;
51 }
52
getDictionary(MLIRContext * context) const53 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
54 if (!isSorted()) {
55 DictionaryAttr::sortInPlace(attrs);
56 dictionarySorted.setPointerAndInt(nullptr, true);
57 }
58 if (!dictionarySorted.getPointer())
59 dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
60 return dictionarySorted.getPointer().cast<DictionaryAttr>();
61 }
62
63 /// Add an attribute with the specified name.
append(StringRef name,Attribute attr)64 void NamedAttrList::append(StringRef name, Attribute attr) {
65 append(StringAttr::get(attr.getContext(), name), attr);
66 }
67
68 /// Replaces the attributes with new list of attributes.
assign(const_iterator inStart,const_iterator inEnd)69 void NamedAttrList::assign(const_iterator inStart, const_iterator inEnd) {
70 DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
71 dictionarySorted.setPointerAndInt(nullptr, true);
72 }
73
push_back(NamedAttribute newAttribute)74 void NamedAttrList::push_back(NamedAttribute newAttribute) {
75 if (isSorted())
76 dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
77 dictionarySorted.setPointer(nullptr);
78 attrs.push_back(newAttribute);
79 }
80
81 /// Return the specified attribute if present, null otherwise.
get(StringRef name) const82 Attribute NamedAttrList::get(StringRef name) const {
83 auto it = findAttr(*this, name);
84 return it.second ? it.first->getValue() : Attribute();
85 }
get(StringAttr name) const86 Attribute NamedAttrList::get(StringAttr name) const {
87 auto it = findAttr(*this, name);
88 return it.second ? it.first->getValue() : Attribute();
89 }
90
91 /// Return the specified named attribute if present, None otherwise.
getNamed(StringRef name) const92 Optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
93 auto it = findAttr(*this, name);
94 return it.second ? *it.first : Optional<NamedAttribute>();
95 }
getNamed(StringAttr name) const96 Optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
97 auto it = findAttr(*this, name);
98 return it.second ? *it.first : Optional<NamedAttribute>();
99 }
100
101 /// If the an attribute exists with the specified name, change it to the new
102 /// value. Otherwise, add a new attribute with the specified name/value.
set(StringAttr name,Attribute value)103 Attribute NamedAttrList::set(StringAttr name, Attribute value) {
104 assert(value && "attributes may never be null");
105
106 // Look for an existing attribute with the given name, and set its value
107 // in-place. Return the previous value of the attribute, if there was one.
108 auto it = findAttr(*this, name);
109 if (it.second) {
110 // Update the existing attribute by swapping out the old value for the new
111 // value. Return the old value.
112 Attribute oldValue = it.first->getValue();
113 if (it.first->getValue() != value) {
114 it.first->setValue(value);
115
116 // If the attributes have changed, the dictionary is invalidated.
117 dictionarySorted.setPointer(nullptr);
118 }
119 return oldValue;
120 }
121 // Perform a string lookup to insert the new attribute into its sorted
122 // position.
123 if (isSorted())
124 it = findAttr(*this, name.strref());
125 attrs.insert(it.first, {name, value});
126 // Invalidate the dictionary. Return null as there was no previous value.
127 dictionarySorted.setPointer(nullptr);
128 return Attribute();
129 }
130
set(StringRef name,Attribute value)131 Attribute NamedAttrList::set(StringRef name, Attribute value) {
132 assert(value && "attributes may never be null");
133 return set(mlir::StringAttr::get(value.getContext(), name), value);
134 }
135
136 Attribute
eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it)137 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
138 // Erasing does not affect the sorted property.
139 Attribute attr = it->getValue();
140 attrs.erase(it);
141 dictionarySorted.setPointer(nullptr);
142 return attr;
143 }
144
erase(StringAttr name)145 Attribute NamedAttrList::erase(StringAttr name) {
146 auto it = findAttr(*this, name);
147 return it.second ? eraseImpl(it.first) : Attribute();
148 }
149
erase(StringRef name)150 Attribute NamedAttrList::erase(StringRef name) {
151 auto it = findAttr(*this, name);
152 return it.second ? eraseImpl(it.first) : Attribute();
153 }
154
155 NamedAttrList &
operator =(const SmallVectorImpl<NamedAttribute> & rhs)156 NamedAttrList::operator=(const SmallVectorImpl<NamedAttribute> &rhs) {
157 assign(rhs.begin(), rhs.end());
158 return *this;
159 }
160
operator ArrayRef<NamedAttribute>() const161 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
162
163 //===----------------------------------------------------------------------===//
164 // OperationState
165 //===----------------------------------------------------------------------===//
166
OperationState(Location location,StringRef name)167 OperationState::OperationState(Location location, StringRef name)
168 : location(location), name(name, location->getContext()) {}
169
OperationState(Location location,OperationName name)170 OperationState::OperationState(Location location, OperationName name)
171 : location(location), name(name) {}
172
OperationState(Location location,OperationName name,ValueRange operands,TypeRange types,ArrayRef<NamedAttribute> attributes,BlockRange successors,MutableArrayRef<std::unique_ptr<Region>> regions)173 OperationState::OperationState(Location location, OperationName name,
174 ValueRange operands, TypeRange types,
175 ArrayRef<NamedAttribute> attributes,
176 BlockRange successors,
177 MutableArrayRef<std::unique_ptr<Region>> regions)
178 : location(location), name(name),
179 operands(operands.begin(), operands.end()),
180 types(types.begin(), types.end()),
181 attributes(attributes.begin(), attributes.end()),
182 successors(successors.begin(), successors.end()) {
183 for (std::unique_ptr<Region> &r : regions)
184 this->regions.push_back(std::move(r));
185 }
OperationState(Location location,StringRef name,ValueRange operands,TypeRange types,ArrayRef<NamedAttribute> attributes,BlockRange successors,MutableArrayRef<std::unique_ptr<Region>> regions)186 OperationState::OperationState(Location location, StringRef name,
187 ValueRange operands, TypeRange types,
188 ArrayRef<NamedAttribute> attributes,
189 BlockRange successors,
190 MutableArrayRef<std::unique_ptr<Region>> regions)
191 : OperationState(location, OperationName(name, location.getContext()),
192 operands, types, attributes, successors, regions) {}
193
addOperands(ValueRange newOperands)194 void OperationState::addOperands(ValueRange newOperands) {
195 operands.append(newOperands.begin(), newOperands.end());
196 }
197
addSuccessors(BlockRange newSuccessors)198 void OperationState::addSuccessors(BlockRange newSuccessors) {
199 successors.append(newSuccessors.begin(), newSuccessors.end());
200 }
201
addRegion()202 Region *OperationState::addRegion() {
203 regions.emplace_back(new Region);
204 return regions.back().get();
205 }
206
addRegion(std::unique_ptr<Region> && region)207 void OperationState::addRegion(std::unique_ptr<Region> &®ion) {
208 regions.push_back(std::move(region));
209 }
210
addRegions(MutableArrayRef<std::unique_ptr<Region>> regions)211 void OperationState::addRegions(
212 MutableArrayRef<std::unique_ptr<Region>> regions) {
213 for (std::unique_ptr<Region> ®ion : regions)
214 addRegion(std::move(region));
215 }
216
217 //===----------------------------------------------------------------------===//
218 // OperandStorage
219 //===----------------------------------------------------------------------===//
220
OperandStorage(Operation * owner,OpOperand * trailingOperands,ValueRange values)221 detail::OperandStorage::OperandStorage(Operation *owner,
222 OpOperand *trailingOperands,
223 ValueRange values)
224 : isStorageDynamic(false), operandStorage(trailingOperands) {
225 numOperands = capacity = values.size();
226 for (unsigned i = 0; i < numOperands; ++i)
227 new (&operandStorage[i]) OpOperand(owner, values[i]);
228 }
229
~OperandStorage()230 detail::OperandStorage::~OperandStorage() {
231 for (auto &operand : getOperands())
232 operand.~OpOperand();
233
234 // If the storage is dynamic, deallocate it.
235 if (isStorageDynamic)
236 free(operandStorage);
237 }
238
239 /// Replace the operands contained in the storage with the ones provided in
240 /// 'values'.
setOperands(Operation * owner,ValueRange values)241 void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
242 MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
243 for (unsigned i = 0, e = values.size(); i != e; ++i)
244 storageOperands[i].set(values[i]);
245 }
246
247 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
248 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
249 /// than the range pointed to by 'start'+'length'.
setOperands(Operation * owner,unsigned start,unsigned length,ValueRange operands)250 void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
251 unsigned length, ValueRange operands) {
252 // If the new size is the same, we can update inplace.
253 unsigned newSize = operands.size();
254 if (newSize == length) {
255 MutableArrayRef<OpOperand> storageOperands = getOperands();
256 for (unsigned i = 0, e = length; i != e; ++i)
257 storageOperands[start + i].set(operands[i]);
258 return;
259 }
260 // If the new size is greater, remove the extra operands and set the rest
261 // inplace.
262 if (newSize < length) {
263 eraseOperands(start + operands.size(), length - newSize);
264 setOperands(owner, start, newSize, operands);
265 return;
266 }
267 // Otherwise, the new size is greater so we need to grow the storage.
268 auto storageOperands = resize(owner, size() + (newSize - length));
269
270 // Shift operands to the right to make space for the new operands.
271 unsigned rotateSize = storageOperands.size() - (start + length);
272 auto rbegin = storageOperands.rbegin();
273 std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
274
275 // Update the operands inplace.
276 for (unsigned i = 0, e = operands.size(); i != e; ++i)
277 storageOperands[start + i].set(operands[i]);
278 }
279
280 /// Erase an operand held by the storage.
eraseOperands(unsigned start,unsigned length)281 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
282 MutableArrayRef<OpOperand> operands = getOperands();
283 assert((start + length) <= operands.size());
284 numOperands -= length;
285
286 // Shift all operands down if the operand to remove is not at the end.
287 if (start != numOperands) {
288 auto *indexIt = std::next(operands.begin(), start);
289 std::rotate(indexIt, std::next(indexIt, length), operands.end());
290 }
291 for (unsigned i = 0; i != length; ++i)
292 operands[numOperands + i].~OpOperand();
293 }
294
eraseOperands(const BitVector & eraseIndices)295 void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
296 MutableArrayRef<OpOperand> operands = getOperands();
297 assert(eraseIndices.size() == operands.size());
298
299 // Check that at least one operand is erased.
300 int firstErasedIndice = eraseIndices.find_first();
301 if (firstErasedIndice == -1)
302 return;
303
304 // Shift all of the removed operands to the end, and destroy them.
305 numOperands = firstErasedIndice;
306 for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
307 if (!eraseIndices.test(i))
308 operands[numOperands++] = std::move(operands[i]);
309 for (OpOperand &operand : operands.drop_front(numOperands))
310 operand.~OpOperand();
311 }
312
/// Resize the storage to the given size. Returns the array containing the new
/// operands.
///
/// - Shrinking destroys the trailing operands in place.
/// - Growing within the current capacity placement-constructs the additional
///   operands (owned by `owner`) in place.
/// - Growing past the capacity malloc's a new buffer, moves every existing
///   operand into it, constructs the additional operands, frees the previous
///   buffer only if it was itself dynamic, and marks the storage dynamic.
MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
                                                          unsigned newSize) {
  // If the number of operands is less than or equal to the current amount, we
  // can just update in place.
  MutableArrayRef<OpOperand> origOperands = getOperands();
  if (newSize <= numOperands) {
    // If the new size is smaller than the current size, destroy the extra
    // trailing operands.
    for (unsigned i = newSize; i != numOperands; ++i)
      origOperands[i].~OpOperand();
    numOperands = newSize;
    return origOperands.take_front(newSize);
  }

  // If the new size fits within the current capacity (which may be the
  // initial trailing buffer or a previously allocated dynamic one), grow in
  // place by constructing the additional operands.
  if (newSize <= capacity) {
    OpOperand *opBegin = origOperands.data();
    for (unsigned e = newSize; numOperands != e; ++numOperands)
      new (&opBegin[numOperands]) OpOperand(owner);
    return MutableArrayRef<OpOperand>(opBegin, newSize);
  }

  // Otherwise, we need to allocate a new storage. The new capacity is
  // llvm::NextPowerOf2(capacity + 2), but never less than the requested size.
  unsigned newCapacity =
      std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
  // NOTE(review): the malloc result is not checked for null here.
  OpOperand *newOperandStorage =
      reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));

  // Move the current operands to the new storage (move-construct into the
  // uninitialized buffer).
  MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
  std::uninitialized_copy(std::make_move_iterator(origOperands.begin()),
                          std::make_move_iterator(origOperands.end()),
                          newOperands.begin());

  // Destroy the original (moved-from) operands.
  for (auto &operand : origOperands)
    operand.~OpOperand();

  // Initialize any new operands.
  for (unsigned e = newSize; numOperands != e; ++numOperands)
    new (&newOperands[numOperands]) OpOperand(owner);

  // If the current storage is dynamic, free it. A non-dynamic buffer was
  // supplied to the constructor and is not released here.
  if (isStorageDynamic)
    free(operandStorage);

  // Update the storage representation to use the new dynamic storage.
  operandStorage = newOperandStorage;
  capacity = newCapacity;
  isStorageDynamic = true;
  return newOperands;
}
367
368 //===----------------------------------------------------------------------===//
369 // Operation Value-Iterators
370 //===----------------------------------------------------------------------===//
371
372 //===----------------------------------------------------------------------===//
373 // OperandRange
374
getBeginOperandIndex() const375 unsigned OperandRange::getBeginOperandIndex() const {
376 assert(!empty() && "range must not be empty");
377 return base->getOperandNumber();
378 }
379
split(ElementsAttr segmentSizes) const380 OperandRangeRange OperandRange::split(ElementsAttr segmentSizes) const {
381 return OperandRangeRange(*this, segmentSizes);
382 }
383
384 //===----------------------------------------------------------------------===//
385 // OperandRangeRange
386
OperandRangeRange(OperandRange operands,Attribute operandSegments)387 OperandRangeRange::OperandRangeRange(OperandRange operands,
388 Attribute operandSegments)
389 : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
390 operandSegments.cast<DenseElementsAttr>().size()) {}
391
join() const392 OperandRange OperandRangeRange::join() const {
393 const OwnerT &owner = getBase();
394 auto sizeData = owner.second.cast<DenseElementsAttr>().getValues<uint32_t>();
395 return OperandRange(owner.first,
396 std::accumulate(sizeData.begin(), sizeData.end(), 0));
397 }
398
dereference(const OwnerT & object,ptrdiff_t index)399 OperandRange OperandRangeRange::dereference(const OwnerT &object,
400 ptrdiff_t index) {
401 auto sizeData = object.second.cast<DenseElementsAttr>().getValues<uint32_t>();
402 uint32_t startIndex =
403 std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
404 return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
405 }
406
407 //===----------------------------------------------------------------------===//
408 // MutableOperandRange
409
410 /// Construct a new mutable range from the given operand, operand start index,
411 /// and range length.
MutableOperandRange(Operation * owner,unsigned start,unsigned length,ArrayRef<OperandSegment> operandSegments)412 MutableOperandRange::MutableOperandRange(
413 Operation *owner, unsigned start, unsigned length,
414 ArrayRef<OperandSegment> operandSegments)
415 : owner(owner), start(start), length(length),
416 operandSegments(operandSegments.begin(), operandSegments.end()) {
417 assert((start + length) <= owner->getNumOperands() && "invalid range");
418 }
MutableOperandRange(Operation * owner)419 MutableOperandRange::MutableOperandRange(Operation *owner)
420 : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
421
422 /// Slice this range into a sub range, with the additional operand segment.
423 MutableOperandRange
slice(unsigned subStart,unsigned subLen,Optional<OperandSegment> segment) const424 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
425 Optional<OperandSegment> segment) const {
426 assert((subStart + subLen) <= length && "invalid sub-range");
427 MutableOperandRange subSlice(owner, start + subStart, subLen,
428 operandSegments);
429 if (segment)
430 subSlice.operandSegments.push_back(*segment);
431 return subSlice;
432 }
433
434 /// Append the given values to the range.
append(ValueRange values)435 void MutableOperandRange::append(ValueRange values) {
436 if (values.empty())
437 return;
438 owner->insertOperands(start + length, values);
439 updateLength(length + values.size());
440 }
441
442 /// Assign this range to the given values.
assign(ValueRange values)443 void MutableOperandRange::assign(ValueRange values) {
444 owner->setOperands(start, length, values);
445 if (length != values.size())
446 updateLength(/*newLength=*/values.size());
447 }
448
449 /// Assign the range to the given value.
assign(Value value)450 void MutableOperandRange::assign(Value value) {
451 if (length == 1) {
452 owner->setOperand(start, value);
453 } else {
454 owner->setOperands(start, length, value);
455 updateLength(/*newLength=*/1);
456 }
457 }
458
459 /// Erase the operands within the given sub-range.
erase(unsigned subStart,unsigned subLen)460 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
461 assert((subStart + subLen) <= length && "invalid sub-range");
462 if (length == 0)
463 return;
464 owner->eraseOperands(start + subStart, subLen);
465 updateLength(length - subLen);
466 }
467
468 /// Clear this range and erase all of the operands.
clear()469 void MutableOperandRange::clear() {
470 if (length != 0) {
471 owner->eraseOperands(start, length);
472 updateLength(/*newLength=*/0);
473 }
474 }
475
476 /// Allow implicit conversion to an OperandRange.
operator OperandRange() const477 MutableOperandRange::operator OperandRange() const {
478 return owner->getOperands().slice(start, length);
479 }
480
481 MutableOperandRangeRange
split(NamedAttribute segmentSizes) const482 MutableOperandRange::split(NamedAttribute segmentSizes) const {
483 return MutableOperandRangeRange(*this, segmentSizes);
484 }
485
486 /// Update the length of this range to the one provided.
updateLength(unsigned newLength)487 void MutableOperandRange::updateLength(unsigned newLength) {
488 int32_t diff = int32_t(newLength) - int32_t(length);
489 length = newLength;
490
491 // Update any of the provided segment attributes.
492 for (OperandSegment &segment : operandSegments) {
493 auto attr = segment.second.getValue().cast<DenseIntElementsAttr>();
494 SmallVector<int32_t, 8> segments(attr.getValues<int32_t>());
495 segments[segment.first] += diff;
496 segment.second.setValue(
497 DenseIntElementsAttr::get(attr.getType(), segments));
498 owner->setAttr(segment.second.getName(), segment.second.getValue());
499 }
500 }
501
502 //===----------------------------------------------------------------------===//
503 // MutableOperandRangeRange
504
MutableOperandRangeRange(const MutableOperandRange & operands,NamedAttribute operandSegmentAttr)505 MutableOperandRangeRange::MutableOperandRangeRange(
506 const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
507 : MutableOperandRangeRange(
508 OwnerT(operands, operandSegmentAttr), 0,
509 operandSegmentAttr.getValue().cast<DenseElementsAttr>().size()) {}
510
join() const511 MutableOperandRange MutableOperandRangeRange::join() const {
512 return getBase().first;
513 }
514
operator OperandRangeRange() const515 MutableOperandRangeRange::operator OperandRangeRange() const {
516 return OperandRangeRange(
517 getBase().first, getBase().second.getValue().cast<DenseElementsAttr>());
518 }
519
dereference(const OwnerT & object,ptrdiff_t index)520 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
521 ptrdiff_t index) {
522 auto sizeData =
523 object.second.getValue().cast<DenseElementsAttr>().getValues<uint32_t>();
524 uint32_t startIndex =
525 std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
526 return object.first.slice(
527 startIndex, *(sizeData.begin() + index),
528 MutableOperandRange::OperandSegment(index, object.second));
529 }
530
531 //===----------------------------------------------------------------------===//
532 // ResultRange
533
ResultRange(OpResult result)534 ResultRange::ResultRange(OpResult result)
535 : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
536 1) {}
537
getUses() const538 ResultRange::use_range ResultRange::getUses() const {
539 return {use_begin(), use_end()};
540 }
use_begin() const541 ResultRange::use_iterator ResultRange::use_begin() const {
542 return use_iterator(*this);
543 }
use_end() const544 ResultRange::use_iterator ResultRange::use_end() const {
545 return use_iterator(*this, /*end=*/true);
546 }
getUsers()547 ResultRange::user_range ResultRange::getUsers() {
548 return {user_begin(), user_end()};
549 }
user_begin()550 ResultRange::user_iterator ResultRange::user_begin() {
551 return user_iterator(use_begin());
552 }
user_end()553 ResultRange::user_iterator ResultRange::user_end() {
554 return user_iterator(use_end());
555 }
556
UseIterator(ResultRange results,bool end)557 ResultRange::UseIterator::UseIterator(ResultRange results, bool end)
558 : it(end ? results.end() : results.begin()), endIt(results.end()) {
559 // Only initialize current use if there are results/can be uses.
560 if (it != endIt)
561 skipOverResultsWithNoUsers();
562 }
563
operator ++()564 ResultRange::UseIterator &ResultRange::UseIterator::operator++() {
565 // We increment over uses, if we reach the last use then move to next
566 // result.
567 if (use != (*it).use_end())
568 ++use;
569 if (use == (*it).use_end()) {
570 ++it;
571 skipOverResultsWithNoUsers();
572 }
573 return *this;
574 }
575
skipOverResultsWithNoUsers()576 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
577 while (it != endIt && (*it).use_empty())
578 ++it;
579
580 // If we are at the last result, then set use to first use of
581 // first result (sentinel value used for end).
582 if (it == endIt)
583 use = {};
584 else
585 use = (*it).use_begin();
586 }
587
replaceAllUsesWith(Operation * op)588 void ResultRange::replaceAllUsesWith(Operation *op) {
589 replaceAllUsesWith(op->getResults());
590 }
591
592 //===----------------------------------------------------------------------===//
593 // ValueRange
594
ValueRange(ArrayRef<Value> values)595 ValueRange::ValueRange(ArrayRef<Value> values)
596 : ValueRange(values.data(), values.size()) {}
ValueRange(OperandRange values)597 ValueRange::ValueRange(OperandRange values)
598 : ValueRange(values.begin().getBase(), values.size()) {}
ValueRange(ResultRange values)599 ValueRange::ValueRange(ResultRange values)
600 : ValueRange(values.getBase(), values.size()) {}
601
602 /// See `llvm::detail::indexed_accessor_range_base` for details.
offset_base(const OwnerT & owner,ptrdiff_t index)603 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
604 ptrdiff_t index) {
605 if (const auto *value = owner.dyn_cast<const Value *>())
606 return {value + index};
607 if (auto *operand = owner.dyn_cast<OpOperand *>())
608 return {operand + index};
609 return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
610 }
611 /// See `llvm::detail::indexed_accessor_range_base` for details.
dereference_iterator(const OwnerT & owner,ptrdiff_t index)612 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
613 if (const auto *value = owner.dyn_cast<const Value *>())
614 return value[index];
615 if (auto *operand = owner.dyn_cast<OpOperand *>())
616 return operand[index].get();
617 return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
618 }
619
620 //===----------------------------------------------------------------------===//
621 // Operation Equivalency
622 //===----------------------------------------------------------------------===//
623
computeHash(Operation * op,function_ref<llvm::hash_code (Value)> hashOperands,function_ref<llvm::hash_code (Value)> hashResults,Flags flags)624 llvm::hash_code OperationEquivalence::computeHash(
625 Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
626 function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
627 // Hash operations based upon their:
628 // - Operation Name
629 // - Attributes
630 // - Result Types
631 llvm::hash_code hash = llvm::hash_combine(
632 op->getName(), op->getAttrDictionary(), op->getResultTypes());
633
634 // - Operands
635 ValueRange operands = op->getOperands();
636 SmallVector<Value> operandStorage;
637 if (op->hasTrait<mlir::OpTrait::IsCommutative>()) {
638 operandStorage.append(operands.begin(), operands.end());
639 llvm::sort(operandStorage, [](Value a, Value b) -> bool {
640 return a.getAsOpaquePointer() < b.getAsOpaquePointer();
641 });
642 operands = operandStorage;
643 }
644 for (Value operand : operands)
645 hash = llvm::hash_combine(hash, hashOperands(operand));
646
647 // - Operands
648 for (Value result : op->getResults())
649 hash = llvm::hash_combine(hash, hashResults(result));
650 return hash;
651 }
652
/// Return true if regions `lhs` and `rhs` are structurally equivalent.
/// Corresponding blocks (walked pairwise with llvm::all_of_zip) must:
/// - have the same number of arguments;
/// - have arguments of matching types and, unless IgnoreLocations is set,
///   matching locations;
/// - have every argument pair accepted by `mapOperands`;
/// - contain pairwise-equivalent operations per
///   OperationEquivalence::isEquivalentTo, with consistent successor blocks.
/// A block may only ever be paired with one counterpart, whether it was first
/// seen as a region block or as a successor target.
static bool
isRegionEquivalentTo(Region *lhs, Region *rhs,
                     function_ref<LogicalResult(Value, Value)> mapOperands,
                     function_ref<LogicalResult(Value, Value)> mapResults,
                     OperationEquivalence::Flags flags) {
  // lhs block -> rhs block correspondence discovered so far.
  DenseMap<Block *, Block *> blocksMap;
  auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
    // Check block arguments.
    if (lBlock.getNumArguments() != rBlock.getNumArguments())
      return false;

    // Map the two blocks. If lBlock was already paired (e.g. as an earlier
    // successor target) with a different block, the regions disagree.
    auto insertion = blocksMap.insert({&lBlock, &rBlock});
    if (insertion.first->getSecond() != &rBlock)
      return false;

    for (auto argPair :
         llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
      Value curArg = std::get<0>(argPair);
      Value otherArg = std::get<1>(argPair);
      if (curArg.getType() != otherArg.getType())
        return false;
      if (!(flags & OperationEquivalence::IgnoreLocations) &&
          curArg.getLoc() != otherArg.getLoc())
        return false;
      // Check if this value was already mapped to another value.
      if (failed(mapOperands(curArg, otherArg)))
        return false;
    }

    auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
      // Check for op equality (recursively).
      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, mapOperands,
                                                mapResults, flags))
        return false;
      // Check successor mapping: each successor must be consistent with any
      // pairing already recorded in blocksMap.
      for (auto successorsPair :
           llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
        Block *curSuccessor = std::get<0>(successorsPair);
        Block *otherSuccessor = std::get<1>(successorsPair);
        auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
        if (insertion.first->getSecond() != otherSuccessor)
          return false;
      }
      return true;
    };
    return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
  };
  return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
}
703
isEquivalentTo(Operation * lhs,Operation * rhs,function_ref<LogicalResult (Value,Value)> mapOperands,function_ref<LogicalResult (Value,Value)> mapResults,Flags flags)704 bool OperationEquivalence::isEquivalentTo(
705 Operation *lhs, Operation *rhs,
706 function_ref<LogicalResult(Value, Value)> mapOperands,
707 function_ref<LogicalResult(Value, Value)> mapResults, Flags flags) {
708 if (lhs == rhs)
709 return true;
710
711 // Compare the operation properties.
712 if (lhs->getName() != rhs->getName() ||
713 lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
714 lhs->getNumRegions() != rhs->getNumRegions() ||
715 lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
716 lhs->getNumOperands() != rhs->getNumOperands() ||
717 lhs->getNumResults() != rhs->getNumResults())
718 return false;
719 if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
720 return false;
721
722 ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
723 SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
724 if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
725 lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end());
726 llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool {
727 return a.getAsOpaquePointer() < b.getAsOpaquePointer();
728 });
729 lhsOperands = lhsOperandStorage;
730
731 rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end());
732 llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool {
733 return a.getAsOpaquePointer() < b.getAsOpaquePointer();
734 });
735 rhsOperands = rhsOperandStorage;
736 }
737 auto checkValueRangeMapping =
738 [](ValueRange lhs, ValueRange rhs,
739 function_ref<LogicalResult(Value, Value)> mapValues) {
740 for (auto operandPair : llvm::zip(lhs, rhs)) {
741 Value curArg = std::get<0>(operandPair);
742 Value otherArg = std::get<1>(operandPair);
743 if (curArg.getType() != otherArg.getType())
744 return false;
745 if (failed(mapValues(curArg, otherArg)))
746 return false;
747 }
748 return true;
749 };
750 // Check mapping of operands and results.
751 if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands))
752 return false;
753 if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
754 return false;
755 for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
756 if (!isRegionEquivalentTo(&std::get<0>(regionPair),
757 &std::get<1>(regionPair), mapOperands, mapResults,
758 flags))
759 return false;
760 return true;
761 }
762