1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/OperationSupport.h"
15 #include "mlir/IR/Block.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/IR/StandardTypes.h"
19 using namespace mlir;
20 
21 //===----------------------------------------------------------------------===//
22 // OperationState
23 //===----------------------------------------------------------------------===//
24 
/// Construct a state for the operation named 'name' at 'location'. The name
/// is resolved using the MLIRContext that owns 'location'.
OperationState::OperationState(Location location, StringRef name)
    : location(location), name(name, location->getContext()) {}

/// Construct a state for an already-resolved operation name at 'location'.
OperationState::OperationState(Location location, OperationName name)
    : location(location), name(name) {}
30 
31 OperationState::OperationState(Location location, StringRef name,
32                                ValueRange operands, ArrayRef<Type> types,
33                                ArrayRef<NamedAttribute> attributes,
34                                ArrayRef<Block *> successors,
35                                MutableArrayRef<std::unique_ptr<Region>> regions)
36     : location(location), name(name, location->getContext()),
37       operands(operands.begin(), operands.end()),
38       types(types.begin(), types.end()),
39       attributes(attributes.begin(), attributes.end()),
40       successors(successors.begin(), successors.end()) {
41   for (std::unique_ptr<Region> &r : regions)
42     this->regions.push_back(std::move(r));
43 }
44 
45 void OperationState::addOperands(ValueRange newOperands) {
46   operands.append(newOperands.begin(), newOperands.end());
47 }
48 
49 void OperationState::addSuccessors(SuccessorRange newSuccessors) {
50   successors.append(newSuccessors.begin(), newSuccessors.end());
51 }
52 
53 Region *OperationState::addRegion() {
54   regions.emplace_back(new Region);
55   return regions.back().get();
56 }
57 
58 void OperationState::addRegion(std::unique_ptr<Region> &&region) {
59   regions.push_back(std::move(region));
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // OperandStorage
64 //===----------------------------------------------------------------------===//
65 
/// Construct the operand storage with every value of 'values' placed in the
/// inline (trailing-object) buffer. The inline capacity is set to exactly
/// 'values.size()', so growing past it later requires a heap allocation (see
/// resize()).
detail::OperandStorage::OperandStorage(Operation *owner, ValueRange values)
    : representation(0) {
  // NOTE(review): zeroing 'representation' presumably leaves DynamicStorageBit
  // clear, i.e. the storage starts out inline — confirm against the header.
  auto &inlineStorage = getInlineStorage();
  inlineStorage.numOperands = inlineStorage.capacity = values.size();
  auto *operandPtrBegin = getTrailingObjects<OpOperand>();
  // Placement-new each operand into the trailing memory. The caller is
  // presumably responsible for having allocated room for 'values.size()'
  // trailing OpOperands — TODO confirm.
  for (unsigned i = 0, e = inlineStorage.numOperands; i < e; ++i)
    new (&operandPtrBegin[i]) OpOperand(owner, values[i]);
}
74 
75 detail::OperandStorage::~OperandStorage() {
76   // Destruct the current storage container.
77   if (isDynamicStorage()) {
78     TrailingOperandStorage &storage = getDynamicStorage();
79     storage.~TrailingOperandStorage();
80     free(&storage);
81   } else {
82     getInlineStorage().~TrailingOperandStorage();
83   }
84 }
85 
86 /// Replace the operands contained in the storage with the ones provided in
87 /// 'values'.
88 void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
89   MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
90   for (unsigned i = 0, e = values.size(); i != e; ++i)
91     storageOperands[i].set(values[i]);
92 }
93 
/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
/// than the range pointed to by 'start'+'length'. Operands that follow the
/// range are shifted so that they directly follow the new ones.
void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
                                         unsigned length, ValueRange operands) {
  // If the new size is the same, we can update inplace.
  unsigned newSize = operands.size();
  if (newSize == length) {
    MutableArrayRef<OpOperand> storageOperands = getOperands();
    for (unsigned i = 0, e = length; i != e; ++i)
      storageOperands[start + i].set(operands[i]);
    return;
  }
  // If the new size is smaller, erase the surplus operands at the tail of the
  // range, then set the remaining 'newSize' operands inplace. The recursive
  // call takes the same-size path above.
  if (newSize < length) {
    eraseOperands(start + operands.size(), length - newSize);
    setOperands(owner, start, newSize, operands);
    return;
  }
  // Otherwise, the new size is greater so we need to grow the storage.
  // resize() constructs the (newSize - length) extra operands at the end.
  auto storageOperands = resize(owner, size() + (newSize - length));

  // Shift operands to the right to make space for the new operands. The rotate
  // works through reverse iterators over the tail [start + length, end). It
  // moves the freshly-appended operands from the end down to position
  // 'start + length', and shifts the original trailing operands right by
  // (newSize - length).
  unsigned rotateSize = storageOperands.size() - (start + length);
  auto rbegin = storageOperands.rbegin();
  std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);

  // Update the operands inplace.
  for (unsigned i = 0, e = operands.size(); i != e; ++i)
    storageOperands[start + i].set(operands[i]);
}
126 
/// Erase the 'length' operands held by the storage, beginning at index
/// 'start'. Later operands are shifted down to close the gap. The capacity of
/// the storage is left unchanged.
void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
  TrailingOperandStorage &storage = getStorage();
  MutableArrayRef<OpOperand> operands = storage.getOperands();
  assert((start + length) <= operands.size());
  storage.numOperands -= length;

  // Shift all operands down if the operand to remove is not at the end.
  // The rotate moves the erased operands to the tail of 'operands', which
  // still spans the old size, so that they can be destroyed below.
  if (start != storage.numOperands) {
    auto indexIt = std::next(operands.begin(), start);
    std::rotate(indexIt, std::next(indexIt, length), operands.end());
  }
  // Destroy the erased operands, which now sit just past the new size.
  for (unsigned i = 0; i != length; ++i)
    operands[storage.numOperands + i].~OpOperand();
}
142 
/// Resize the storage to the given size. Returns the array containing the new
/// operands.
///
/// The behavior depends on the new size:
/// - Shrinking destroys the surplus operands in place.
/// - Growing within the current capacity constructs the new, value-less
///   operands in place.
/// - Growing past the capacity allocates a new buffer with malloc(), moves the
///   existing operands into it, frees any previous heap buffer, and switches
///   to the dynamic representation.
///
/// Capacity is never reduced.
MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
                                                          unsigned newSize) {
  TrailingOperandStorage &storage = getStorage();

  // If the number of operands is less than or equal to the current amount, we
  // can just update in place.
  unsigned &numOperands = storage.numOperands;
  MutableArrayRef<OpOperand> operands = storage.getOperands();
  if (newSize <= numOperands) {
    // If the number of new size is less than the current, remove any extra
    // operands.
    for (unsigned i = newSize; i != numOperands; ++i)
      operands[i].~OpOperand();
    numOperands = newSize;
    return operands.take_front(newSize);
  }

  // If the new size is within the current capacity (inline or dynamic), grow
  // inplace by constructing the additional operands after the existing ones.
  if (newSize <= storage.capacity) {
    OpOperand *opBegin = operands.data();
    for (unsigned e = newSize; numOperands != e; ++numOperands)
      new (&opBegin[numOperands]) OpOperand(owner);
    return MutableArrayRef<OpOperand>(opBegin, newSize);
  }

  // Otherwise, we need to allocate a new storage. The capacity grows
  // geometrically: it is at least the next power of two strictly above
  // 'capacity + 2', and never less than 'newSize'.
  unsigned newCapacity =
      std::max(unsigned(llvm::NextPowerOf2(storage.capacity + 2)), newSize);
  // NOTE(review): the result of malloc() is not checked for null here.
  auto *newStorageMem =
      malloc(TrailingOperandStorage::totalSizeToAlloc<OpOperand>(newCapacity));
  auto *newStorage = ::new (newStorageMem) TrailingOperandStorage();
  newStorage->numOperands = newSize;
  newStorage->capacity = newCapacity;

  // Move the current operands to the new storage.
  MutableArrayRef<OpOperand> newOperands = newStorage->getOperands();
  std::uninitialized_copy(std::make_move_iterator(operands.begin()),
                          std::make_move_iterator(operands.end()),
                          newOperands.begin());

  // Destroy the original operands.
  for (auto &operand : operands)
    operand.~OpOperand();

  // Initialize any new operands. 'numOperands' still refers to the OLD
  // storage's count. Here it is only reused as the index of the next
  // unconstructed slot in the new storage, whose own count was set above.
  for (unsigned e = newSize; numOperands != e; ++numOperands)
    new (&newOperands[numOperands]) OpOperand(owner);

  // If the current storage is also dynamic, free it. Its operands were already
  // destroyed above, so the TrailingOperandStorage destructor is not run here.
  if (isDynamicStorage())
    free(&storage);

  // Update the storage representation to use the new dynamic storage. The
  // DynamicStorageBit tag is presumably what isDynamicStorage() tests.
  representation = reinterpret_cast<intptr_t>(newStorage);
  representation |= DynamicStorageBit;
  return newOperands;
}
202 
203 //===----------------------------------------------------------------------===//
204 // ResultStorage
205 //===----------------------------------------------------------------------===//
206 
/// Returns the parent operation of this trailing result.
///
/// The owner is recovered purely by pointer arithmetic. The arithmetic below
/// relies on this memory layout:
///   [Operation][getMaxInlineResults() x InLineOpResult][TrailingOpResult...]
/// It subtracts the full inline-result count unconditionally. This presumably
/// assumes every inline slot is present whenever a trailing result exists —
/// TODO confirm against the code that allocates operations.
Operation *detail::TrailingOpResult::getOwner() {
  // We need to do some arithmetic to get the operation pointer. Move the
  // trailing owner to the start of the array.
  TrailingOpResult *trailingIt = this - trailingResultNumber;

  // Move the owner past the inline op results to get to the operation.
  auto *inlineResultIt = reinterpret_cast<InLineOpResult *>(trailingIt) -
                         OpResult::getMaxInlineResults();
  return reinterpret_cast<Operation *>(inlineResultIt) - 1;
}
218 
219 //===----------------------------------------------------------------------===//
220 // Operation Value-Iterators
221 //===----------------------------------------------------------------------===//
222 
223 //===----------------------------------------------------------------------===//
224 // TypeRange
225 
226 TypeRange::TypeRange(ArrayRef<Type> types)
227     : TypeRange(types.data(), types.size()) {}
228 TypeRange::TypeRange(OperandRange values)
229     : TypeRange(values.begin().getBase(), values.size()) {}
230 TypeRange::TypeRange(ResultRange values)
231     : TypeRange(values.getBase()->getResultTypes().slice(values.getStartIndex(),
232                                                          values.size())) {}
233 TypeRange::TypeRange(ArrayRef<Value> values)
234     : TypeRange(values.data(), values.size()) {}
235 TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
236   detail::ValueRangeOwner owner = values.begin().getBase();
237   if (auto *op = reinterpret_cast<Operation *>(owner.ptr.dyn_cast<void *>()))
238     this->base = op->getResultTypes().drop_front(owner.startIndex).data();
239   else if (auto *operand = owner.ptr.dyn_cast<OpOperand *>())
240     this->base = operand;
241   else
242     this->base = owner.ptr.get<const Value *>();
243 }
244 
245 /// See `llvm::detail::indexed_accessor_range_base` for details.
246 TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
247   if (auto *value = object.dyn_cast<const Value *>())
248     return {value + index};
249   if (auto *operand = object.dyn_cast<OpOperand *>())
250     return {operand + index};
251   return {object.dyn_cast<const Type *>() + index};
252 }
253 /// See `llvm::detail::indexed_accessor_range_base` for details.
254 Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
255   if (auto *value = object.dyn_cast<const Value *>())
256     return (value + index)->getType();
257   if (auto *operand = object.dyn_cast<OpOperand *>())
258     return (operand + index)->get().getType();
259   return object.dyn_cast<const Type *>()[index];
260 }
261 
262 //===----------------------------------------------------------------------===//
263 // OperandRange
264 
/// Construct a range covering every operand of 'op'.
OperandRange::OperandRange(Operation *op)
    : OperandRange(op->getOpOperands().data(), op->getNumOperands()) {}
267 
268 /// Return the operand index of the first element of this range. The range
269 /// must not be empty.
270 unsigned OperandRange::getBeginOperandIndex() const {
271   assert(!empty() && "range must not be empty");
272   return base->getOperandNumber();
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // MutableOperandRange
277 
/// Construct a new mutable range from the given operation, operand start
/// index, and range length. 'operandSegments' lists the segment-size
/// attributes that updateLength() rewrites whenever the length of this range
/// changes.
MutableOperandRange::MutableOperandRange(
    Operation *owner, unsigned start, unsigned length,
    ArrayRef<OperandSegment> operandSegments)
    : owner(owner), start(start), length(length),
      operandSegments(operandSegments.begin(), operandSegments.end()) {
  assert((start + length) <= owner->getNumOperands() && "invalid range");
}
/// Construct a mutable range covering every operand of 'owner'.
/// 'operandSegments' takes its default value, presumably empty.
MutableOperandRange::MutableOperandRange(Operation *owner)
    : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
289 
290 /// Slice this range into a sub range, with the additional operand segment.
291 MutableOperandRange
292 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
293                            Optional<OperandSegment> segment) {
294   assert((subStart + subLen) <= length && "invalid sub-range");
295   MutableOperandRange subSlice(owner, start + subStart, subLen,
296                                operandSegments);
297   if (segment)
298     subSlice.operandSegments.push_back(*segment);
299   return subSlice;
300 }
301 
302 /// Append the given values to the range.
303 void MutableOperandRange::append(ValueRange values) {
304   if (values.empty())
305     return;
306   owner->insertOperands(start + length, values);
307   updateLength(length + values.size());
308 }
309 
310 /// Assign this range to the given values.
311 void MutableOperandRange::assign(ValueRange values) {
312   owner->setOperands(start, length, values);
313   if (length != values.size())
314     updateLength(/*newLength=*/values.size());
315 }
316 
317 /// Assign the range to the given value.
318 void MutableOperandRange::assign(Value value) {
319   if (length == 1) {
320     owner->setOperand(start, value);
321   } else {
322     owner->setOperands(start, length, value);
323     updateLength(/*newLength=*/1);
324   }
325 }
326 
327 /// Erase the operands within the given sub-range.
328 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
329   assert((subStart + subLen) <= length && "invalid sub-range");
330   if (length == 0)
331     return;
332   owner->eraseOperands(start + subStart, subLen);
333   updateLength(length - subLen);
334 }
335 
336 /// Clear this range and erase all of the operands.
337 void MutableOperandRange::clear() {
338   if (length != 0) {
339     owner->eraseOperands(start, length);
340     updateLength(/*newLength=*/0);
341   }
342 }
343 
344 /// Allow implicit conversion to an OperandRange.
345 MutableOperandRange::operator OperandRange() const {
346   return owner->getOperands().slice(start, length);
347 }
348 
349 /// Update the length of this range to the one provided.
350 void MutableOperandRange::updateLength(unsigned newLength) {
351   int32_t diff = int32_t(newLength) - int32_t(length);
352   length = newLength;
353 
354   // Update any of the provided segment attributes.
355   for (OperandSegment &segment : operandSegments) {
356     auto attr = segment.second.second.cast<DenseIntElementsAttr>();
357     SmallVector<int32_t, 8> segments(attr.getValues<int32_t>());
358     segments[segment.first] += diff;
359     segment.second.second = DenseIntElementsAttr::get(attr.getType(), segments);
360     owner->setAttr(segment.second.first, segment.second.second);
361   }
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // ResultRange
366 
/// Construct a range covering all of the results of 'op'.
ResultRange::ResultRange(Operation *op)
    : ResultRange(op, /*startIndex=*/0, op->getNumResults()) {}

/// Returns the types of the results in this range, as a slice of the owning
/// operation's result types.
ArrayRef<Type> ResultRange::getTypes() const {
  return getBase()->getResultTypes().slice(getStartIndex(), size());
}

/// See `llvm::indexed_accessor_range` for details.
/// NOTE(review): 'index' is passed to getResult() unchanged, so it is
/// presumably already an absolute result index, with the range's start index
/// folded in by the base class — TODO confirm.
OpResult ResultRange::dereference(Operation *op, ptrdiff_t index) {
  return op->getResult(index);
}
378 
379 //===----------------------------------------------------------------------===//
380 // ValueRange
381 
/// Construct a range over a contiguous array of values.
ValueRange::ValueRange(ArrayRef<Value> values)
    : ValueRange(values.data(), values.size()) {}
/// Construct a range that reads its values through the given operands.
ValueRange::ValueRange(OperandRange values)
    : ValueRange(values.begin().getBase(), values.size()) {}
/// Construct a range over results of an operation. The owner records the
/// operation together with the index of the first result in the range.
ValueRange::ValueRange(ResultRange values)
    : ValueRange(
          {values.getBase(), static_cast<unsigned>(values.getStartIndex())},
          values.size()) {}
390 
391 /// See `llvm::detail::indexed_accessor_range_base` for details.
392 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
393                                            ptrdiff_t index) {
394   if (auto *value = owner.ptr.dyn_cast<const Value *>())
395     return {value + index};
396   if (auto *operand = owner.ptr.dyn_cast<OpOperand *>())
397     return {operand + index};
398   Operation *operation = reinterpret_cast<Operation *>(owner.ptr.get<void *>());
399   return {operation, owner.startIndex + static_cast<unsigned>(index)};
400 }
401 /// See `llvm::detail::indexed_accessor_range_base` for details.
402 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
403   if (auto *value = owner.ptr.dyn_cast<const Value *>())
404     return value[index];
405   if (auto *operand = owner.ptr.dyn_cast<OpOperand *>())
406     return operand[index].get();
407   Operation *operation = reinterpret_cast<Operation *>(owner.ptr.get<void *>());
408   return operation->getResult(owner.startIndex + index);
409 }
410 
411 //===----------------------------------------------------------------------===//
412 // Operation Equivalency
413 //===----------------------------------------------------------------------===//
414 
415 llvm::hash_code OperationEquivalence::computeHash(Operation *op, Flags flags) {
416   // Hash operations based upon their:
417   //   - Operation Name
418   //   - Attributes
419   llvm::hash_code hash = llvm::hash_combine(
420       op->getName(), op->getMutableAttrDict().getDictionary());
421 
422   //   - Result Types
423   ArrayRef<Type> resultTypes = op->getResultTypes();
424   switch (resultTypes.size()) {
425   case 0:
426     // We don't need to add anything to the hash.
427     break;
428   case 1:
429     // Add in the result type.
430     hash = llvm::hash_combine(hash, resultTypes.front());
431     break;
432   default:
433     // Use the type buffer as the hash, as we can guarantee it is the same for
434     // any given range of result types. This takes advantage of the fact the
435     // result types >1 are stored in a TupleType and uniqued.
436     hash = llvm::hash_combine(hash, resultTypes.data());
437     break;
438   }
439 
440   //   - Operands
441   bool ignoreOperands = flags & Flags::IgnoreOperands;
442   if (!ignoreOperands) {
443     // TODO: Allow commutative operations to have different ordering.
444     hash = llvm::hash_combine(
445         hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
446   }
447   return hash;
448 }
449 
450 bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs,
451                                           Flags flags) {
452   if (lhs == rhs)
453     return true;
454 
455   // Compare the operation name.
456   if (lhs->getName() != rhs->getName())
457     return false;
458   // Check operand counts.
459   if (lhs->getNumOperands() != rhs->getNumOperands())
460     return false;
461   // Compare attributes.
462   if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict())
463     return false;
464   // Compare result types.
465   ArrayRef<Type> lhsResultTypes = lhs->getResultTypes();
466   ArrayRef<Type> rhsResultTypes = rhs->getResultTypes();
467   if (lhsResultTypes.size() != rhsResultTypes.size())
468     return false;
469   switch (lhsResultTypes.size()) {
470   case 0:
471     break;
472   case 1:
473     // Compare the single result type.
474     if (lhsResultTypes.front() != rhsResultTypes.front())
475       return false;
476     break;
477   default:
478     // Use the type buffer for the comparison, as we can guarantee it is the
479     // same for any given range of result types. This takes advantage of the
480     // fact the result types >1 are stored in a TupleType and uniqued.
481     if (lhsResultTypes.data() != rhsResultTypes.data())
482       return false;
483     break;
484   }
485   // Compare operands.
486   bool ignoreOperands = flags & Flags::IgnoreOperands;
487   if (ignoreOperands)
488     return true;
489   // TODO: Allow commutative operations to have different ordering.
490   return std::equal(lhs->operand_begin(), lhs->operand_end(),
491                     rhs->operand_begin());
492 }
493