1 //===- OperationSupport.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains out-of-line implementations of the support types that 10 // Operation and related classes build on top of. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/OperationSupport.h" 15 #include "mlir/IR/Block.h" 16 #include "mlir/IR/OpDefinition.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/IR/StandardTypes.h" 19 using namespace mlir; 20 21 //===----------------------------------------------------------------------===// 22 // OperationState 23 //===----------------------------------------------------------------------===// 24 25 OperationState::OperationState(Location location, StringRef name) 26 : location(location), name(name, location->getContext()) {} 27 28 OperationState::OperationState(Location location, OperationName name) 29 : location(location), name(name) {} 30 31 OperationState::OperationState(Location location, StringRef name, 32 ValueRange operands, ArrayRef<Type> types, 33 ArrayRef<NamedAttribute> attributes, 34 ArrayRef<Block *> successors, 35 MutableArrayRef<std::unique_ptr<Region>> regions) 36 : location(location), name(name, location->getContext()), 37 operands(operands.begin(), operands.end()), 38 types(types.begin(), types.end()), 39 attributes(attributes.begin(), attributes.end()), 40 successors(successors.begin(), successors.end()) { 41 for (std::unique_ptr<Region> &r : regions) 42 this->regions.push_back(std::move(r)); 43 } 44 45 void OperationState::addOperands(ValueRange newOperands) { 46 operands.append(newOperands.begin(), newOperands.end()); 47 } 48 49 void OperationState::addSuccessors(SuccessorRange newSuccessors) { 50 successors.append(newSuccessors.begin(), newSuccessors.end()); 51 } 52 53 Region *OperationState::addRegion() { 54 regions.emplace_back(new Region); 55 return regions.back().get(); 56 } 57 58 void OperationState::addRegion(std::unique_ptr<Region> &®ion) { 59 regions.push_back(std::move(region)); 60 } 61 62 //===----------------------------------------------------------------------===// 63 // OperandStorage 64 //===----------------------------------------------------------------------===// 65 66 detail::OperandStorage::OperandStorage(Operation *owner, ValueRange values) 67 : representation(0) { 68 auto &inlineStorage = getInlineStorage(); 69 inlineStorage.numOperands = inlineStorage.capacity = values.size(); 70 auto *operandPtrBegin = getTrailingObjects<OpOperand>(); 71 for (unsigned i = 0, e = inlineStorage.numOperands; i < e; ++i) 72 new (&operandPtrBegin[i]) OpOperand(owner, values[i]); 73 } 74 75 detail::OperandStorage::~OperandStorage() { 76 // Destruct the current storage container. 77 if (isDynamicStorage()) { 78 TrailingOperandStorage &storage = getDynamicStorage(); 79 storage.~TrailingOperandStorage(); 80 free(&storage); 81 } else { 82 getInlineStorage().~TrailingOperandStorage(); 83 } 84 } 85 86 /// Replace the operands contained in the storage with the ones provided in 87 /// 'values'. 88 void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) { 89 MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size()); 90 for (unsigned i = 0, e = values.size(); i != e; ++i) 91 storageOperands[i].set(values[i]); 92 } 93 94 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 95 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 96 /// than the range pointed to by 'start'+'length'. 97 void detail::OperandStorage::setOperands(Operation *owner, unsigned start, 98 unsigned length, ValueRange operands) { 99 // If the new size is the same, we can update inplace. 100 unsigned newSize = operands.size(); 101 if (newSize == length) { 102 MutableArrayRef<OpOperand> storageOperands = getOperands(); 103 for (unsigned i = 0, e = length; i != e; ++i) 104 storageOperands[start + i].set(operands[i]); 105 return; 106 } 107 // If the new size is greater, remove the extra operands and set the rest 108 // inplace. 109 if (newSize < length) { 110 eraseOperands(start + operands.size(), length - newSize); 111 setOperands(owner, start, newSize, operands); 112 return; 113 } 114 // Otherwise, the new size is greater so we need to grow the storage. 115 auto storageOperands = resize(owner, size() + (newSize - length)); 116 117 // Shift operands to the right to make space for the new operands. 118 unsigned rotateSize = storageOperands.size() - (start + length); 119 auto rbegin = storageOperands.rbegin(); 120 std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize); 121 122 // Update the operands inplace. 123 for (unsigned i = 0, e = operands.size(); i != e; ++i) 124 storageOperands[start + i].set(operands[i]); 125 } 126 127 /// Erase an operand held by the storage. 128 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) { 129 TrailingOperandStorage &storage = getStorage(); 130 MutableArrayRef<OpOperand> operands = storage.getOperands(); 131 assert((start + length) <= operands.size()); 132 storage.numOperands -= length; 133 134 // Shift all operands down if the operand to remove is not at the end. 135 if (start != storage.numOperands) { 136 auto indexIt = std::next(operands.begin(), start); 137 std::rotate(indexIt, std::next(indexIt, length), operands.end()); 138 } 139 for (unsigned i = 0; i != length; ++i) 140 operands[storage.numOperands + i].~OpOperand(); 141 } 142 143 /// Resize the storage to the given size. Returns the array containing the new 144 /// operands. 145 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner, 146 unsigned newSize) { 147 TrailingOperandStorage &storage = getStorage(); 148 149 // If the number of operands is less than or equal to the current amount, we 150 // can just update in place. 151 unsigned &numOperands = storage.numOperands; 152 MutableArrayRef<OpOperand> operands = storage.getOperands(); 153 if (newSize <= numOperands) { 154 // If the number of new size is less than the current, remove any extra 155 // operands. 156 for (unsigned i = newSize; i != numOperands; ++i) 157 operands[i].~OpOperand(); 158 numOperands = newSize; 159 return operands.take_front(newSize); 160 } 161 162 // If the new size is within the original inline capacity, grow inplace. 163 if (newSize <= storage.capacity) { 164 OpOperand *opBegin = operands.data(); 165 for (unsigned e = newSize; numOperands != e; ++numOperands) 166 new (&opBegin[numOperands]) OpOperand(owner); 167 return MutableArrayRef<OpOperand>(opBegin, newSize); 168 } 169 170 // Otherwise, we need to allocate a new storage. 171 unsigned newCapacity = 172 std::max(unsigned(llvm::NextPowerOf2(storage.capacity + 2)), newSize); 173 auto *newStorageMem = 174 malloc(TrailingOperandStorage::totalSizeToAlloc<OpOperand>(newCapacity)); 175 auto *newStorage = ::new (newStorageMem) TrailingOperandStorage(); 176 newStorage->numOperands = newSize; 177 newStorage->capacity = newCapacity; 178 179 // Move the current operands to the new storage. 180 MutableArrayRef<OpOperand> newOperands = newStorage->getOperands(); 181 std::uninitialized_copy(std::make_move_iterator(operands.begin()), 182 std::make_move_iterator(operands.end()), 183 newOperands.begin()); 184 185 // Destroy the original operands. 186 for (auto &operand : operands) 187 operand.~OpOperand(); 188 189 // Initialize any new operands. 190 for (unsigned e = newSize; numOperands != e; ++numOperands) 191 new (&newOperands[numOperands]) OpOperand(owner); 192 193 // If the current storage is also dynamic, free it. 194 if (isDynamicStorage()) 195 free(&storage); 196 197 // Update the storage representation to use the new dynamic storage. 198 representation = reinterpret_cast<intptr_t>(newStorage); 199 representation |= DynamicStorageBit; 200 return newOperands; 201 } 202 203 //===----------------------------------------------------------------------===// 204 // ResultStorage 205 //===----------------------------------------------------------------------===// 206 207 /// Returns the parent operation of this trailing result. 208 Operation *detail::TrailingOpResult::getOwner() { 209 // We need to do some arithmetic to get the operation pointer. Move the 210 // trailing owner to the start of the array. 211 TrailingOpResult *trailingIt = this - trailingResultNumber; 212 213 // Move the owner past the inline op results to get to the operation. 214 auto *inlineResultIt = reinterpret_cast<InLineOpResult *>(trailingIt) - 215 OpResult::getMaxInlineResults(); 216 return reinterpret_cast<Operation *>(inlineResultIt) - 1; 217 } 218 219 //===----------------------------------------------------------------------===// 220 // Operation Value-Iterators 221 //===----------------------------------------------------------------------===// 222 223 //===----------------------------------------------------------------------===// 224 // TypeRange 225 226 TypeRange::TypeRange(ArrayRef<Type> types) 227 : TypeRange(types.data(), types.size()) {} 228 TypeRange::TypeRange(OperandRange values) 229 : TypeRange(values.begin().getBase(), values.size()) {} 230 TypeRange::TypeRange(ResultRange values) 231 : TypeRange(values.getBase()->getResultTypes().slice(values.getStartIndex(), 232 values.size())) {} 233 TypeRange::TypeRange(ArrayRef<Value> values) 234 : TypeRange(values.data(), values.size()) {} 235 TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) { 236 detail::ValueRangeOwner owner = values.begin().getBase(); 237 if (auto *op = reinterpret_cast<Operation *>(owner.ptr.dyn_cast<void *>())) 238 this->base = op->getResultTypes().drop_front(owner.startIndex).data(); 239 else if (auto *operand = owner.ptr.dyn_cast<OpOperand *>()) 240 this->base = operand; 241 else 242 this->base = owner.ptr.get<const Value *>(); 243 } 244 245 /// See `llvm::detail::indexed_accessor_range_base` for details. 246 TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) { 247 if (auto *value = object.dyn_cast<const Value *>()) 248 return {value + index}; 249 if (auto *operand = object.dyn_cast<OpOperand *>()) 250 return {operand + index}; 251 return {object.dyn_cast<const Type *>() + index}; 252 } 253 /// See `llvm::detail::indexed_accessor_range_base` for details. 254 Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) { 255 if (auto *value = object.dyn_cast<const Value *>()) 256 return (value + index)->getType(); 257 if (auto *operand = object.dyn_cast<OpOperand *>()) 258 return (operand + index)->get().getType(); 259 return object.dyn_cast<const Type *>()[index]; 260 } 261 262 //===----------------------------------------------------------------------===// 263 // OperandRange 264 265 OperandRange::OperandRange(Operation *op) 266 : OperandRange(op->getOpOperands().data(), op->getNumOperands()) {} 267 268 /// Return the operand index of the first element of this range. The range 269 /// must not be empty. 270 unsigned OperandRange::getBeginOperandIndex() const { 271 assert(!empty() && "range must not be empty"); 272 return base->getOperandNumber(); 273 } 274 275 //===----------------------------------------------------------------------===// 276 // MutableOperandRange 277 278 /// Construct a new mutable range from the given operand, operand start index, 279 /// and range length. 280 MutableOperandRange::MutableOperandRange( 281 Operation *owner, unsigned start, unsigned length, 282 ArrayRef<OperandSegment> operandSegments) 283 : owner(owner), start(start), length(length), 284 operandSegments(operandSegments.begin(), operandSegments.end()) { 285 assert((start + length) <= owner->getNumOperands() && "invalid range"); 286 } 287 MutableOperandRange::MutableOperandRange(Operation *owner) 288 : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {} 289 290 /// Slice this range into a sub range, with the additional operand segment. 291 MutableOperandRange 292 MutableOperandRange::slice(unsigned subStart, unsigned subLen, 293 Optional<OperandSegment> segment) { 294 assert((subStart + subLen) <= length && "invalid sub-range"); 295 MutableOperandRange subSlice(owner, start + subStart, subLen, 296 operandSegments); 297 if (segment) 298 subSlice.operandSegments.push_back(*segment); 299 return subSlice; 300 } 301 302 /// Append the given values to the range. 303 void MutableOperandRange::append(ValueRange values) { 304 if (values.empty()) 305 return; 306 owner->insertOperands(start + length, values); 307 updateLength(length + values.size()); 308 } 309 310 /// Assign this range to the given values. 311 void MutableOperandRange::assign(ValueRange values) { 312 owner->setOperands(start, length, values); 313 if (length != values.size()) 314 updateLength(/*newLength=*/values.size()); 315 } 316 317 /// Assign the range to the given value. 318 void MutableOperandRange::assign(Value value) { 319 if (length == 1) { 320 owner->setOperand(start, value); 321 } else { 322 owner->setOperands(start, length, value); 323 updateLength(/*newLength=*/1); 324 } 325 } 326 327 /// Erase the operands within the given sub-range. 328 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) { 329 assert((subStart + subLen) <= length && "invalid sub-range"); 330 if (length == 0) 331 return; 332 owner->eraseOperands(start + subStart, subLen); 333 updateLength(length - subLen); 334 } 335 336 /// Clear this range and erase all of the operands. 337 void MutableOperandRange::clear() { 338 if (length != 0) { 339 owner->eraseOperands(start, length); 340 updateLength(/*newLength=*/0); 341 } 342 } 343 344 /// Allow implicit conversion to an OperandRange. 345 MutableOperandRange::operator OperandRange() const { 346 return owner->getOperands().slice(start, length); 347 } 348 349 /// Update the length of this range to the one provided. 350 void MutableOperandRange::updateLength(unsigned newLength) { 351 int32_t diff = int32_t(newLength) - int32_t(length); 352 length = newLength; 353 354 // Update any of the provided segment attributes. 355 for (OperandSegment &segment : operandSegments) { 356 auto attr = segment.second.second.cast<DenseIntElementsAttr>(); 357 SmallVector<int32_t, 8> segments(attr.getValues<int32_t>()); 358 segments[segment.first] += diff; 359 segment.second.second = DenseIntElementsAttr::get(attr.getType(), segments); 360 owner->setAttr(segment.second.first, segment.second.second); 361 } 362 } 363 364 //===----------------------------------------------------------------------===// 365 // ResultRange 366 367 ResultRange::ResultRange(Operation *op) 368 : ResultRange(op, /*startIndex=*/0, op->getNumResults()) {} 369 370 ArrayRef<Type> ResultRange::getTypes() const { 371 return getBase()->getResultTypes().slice(getStartIndex(), size()); 372 } 373 374 /// See `llvm::indexed_accessor_range` for details. 375 OpResult ResultRange::dereference(Operation *op, ptrdiff_t index) { 376 return op->getResult(index); 377 } 378 379 //===----------------------------------------------------------------------===// 380 // ValueRange 381 382 ValueRange::ValueRange(ArrayRef<Value> values) 383 : ValueRange(values.data(), values.size()) {} 384 ValueRange::ValueRange(OperandRange values) 385 : ValueRange(values.begin().getBase(), values.size()) {} 386 ValueRange::ValueRange(ResultRange values) 387 : ValueRange( 388 {values.getBase(), static_cast<unsigned>(values.getStartIndex())}, 389 values.size()) {} 390 391 /// See `llvm::detail::indexed_accessor_range_base` for details. 392 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner, 393 ptrdiff_t index) { 394 if (auto *value = owner.ptr.dyn_cast<const Value *>()) 395 return {value + index}; 396 if (auto *operand = owner.ptr.dyn_cast<OpOperand *>()) 397 return {operand + index}; 398 Operation *operation = reinterpret_cast<Operation *>(owner.ptr.get<void *>()); 399 return {operation, owner.startIndex + static_cast<unsigned>(index)}; 400 } 401 /// See `llvm::detail::indexed_accessor_range_base` for details. 402 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) { 403 if (auto *value = owner.ptr.dyn_cast<const Value *>()) 404 return value[index]; 405 if (auto *operand = owner.ptr.dyn_cast<OpOperand *>()) 406 return operand[index].get(); 407 Operation *operation = reinterpret_cast<Operation *>(owner.ptr.get<void *>()); 408 return operation->getResult(owner.startIndex + index); 409 } 410 411 //===----------------------------------------------------------------------===// 412 // Operation Equivalency 413 //===----------------------------------------------------------------------===// 414 415 llvm::hash_code OperationEquivalence::computeHash(Operation *op, Flags flags) { 416 // Hash operations based upon their: 417 // - Operation Name 418 // - Attributes 419 llvm::hash_code hash = llvm::hash_combine( 420 op->getName(), op->getMutableAttrDict().getDictionary()); 421 422 // - Result Types 423 ArrayRef<Type> resultTypes = op->getResultTypes(); 424 switch (resultTypes.size()) { 425 case 0: 426 // We don't need to add anything to the hash. 427 break; 428 case 1: 429 // Add in the result type. 430 hash = llvm::hash_combine(hash, resultTypes.front()); 431 break; 432 default: 433 // Use the type buffer as the hash, as we can guarantee it is the same for 434 // any given range of result types. This takes advantage of the fact the 435 // result types >1 are stored in a TupleType and uniqued. 436 hash = llvm::hash_combine(hash, resultTypes.data()); 437 break; 438 } 439 440 // - Operands 441 bool ignoreOperands = flags & Flags::IgnoreOperands; 442 if (!ignoreOperands) { 443 // TODO: Allow commutative operations to have different ordering. 444 hash = llvm::hash_combine( 445 hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end())); 446 } 447 return hash; 448 } 449 450 bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs, 451 Flags flags) { 452 if (lhs == rhs) 453 return true; 454 455 // Compare the operation name. 456 if (lhs->getName() != rhs->getName()) 457 return false; 458 // Check operand counts. 459 if (lhs->getNumOperands() != rhs->getNumOperands()) 460 return false; 461 // Compare attributes. 462 if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict()) 463 return false; 464 // Compare result types. 465 ArrayRef<Type> lhsResultTypes = lhs->getResultTypes(); 466 ArrayRef<Type> rhsResultTypes = rhs->getResultTypes(); 467 if (lhsResultTypes.size() != rhsResultTypes.size()) 468 return false; 469 switch (lhsResultTypes.size()) { 470 case 0: 471 break; 472 case 1: 473 // Compare the single result type. 474 if (lhsResultTypes.front() != rhsResultTypes.front()) 475 return false; 476 break; 477 default: 478 // Use the type buffer for the comparison, as we can guarantee it is the 479 // same for any given range of result types. This takes advantage of the 480 // fact the result types >1 are stored in a TupleType and uniqued. 481 if (lhsResultTypes.data() != rhsResultTypes.data()) 482 return false; 483 break; 484 } 485 // Compare operands. 486 bool ignoreOperands = flags & Flags::IgnoreOperands; 487 if (ignoreOperands) 488 return true; 489 // TODO: Allow commutative operations to have different ordering. 490 return std::equal(lhs->operand_begin(), lhs->operand_end(), 491 rhs->operand_begin()); 492 } 493