1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Interfaces/FoldInterfaces.h" 17 #include <numeric> 18 19 using namespace mlir; 20 21 OpAsmParser::~OpAsmParser() {} 22 23 //===----------------------------------------------------------------------===// 24 // OperationName 25 //===----------------------------------------------------------------------===// 26 27 /// Form the OperationName for an op with the specified string. This either is 28 /// a reference to an AbstractOperation if one is known, or a uniqued Identifier 29 /// if not. 30 OperationName::OperationName(StringRef name, MLIRContext *context) { 31 if (auto *op = AbstractOperation::lookup(name, context)) 32 representation = op; 33 else 34 representation = Identifier::get(name, context); 35 } 36 37 /// Return the name of the dialect this operation is registered to. 38 StringRef OperationName::getDialectNamespace() const { 39 if (Dialect *dialect = getDialect()) 40 return dialect->getNamespace(); 41 return representation.get<Identifier>().strref().split('.').first; 42 } 43 44 /// Return the operation name with dialect name stripped, if it has one. 45 StringRef OperationName::stripDialect() const { 46 auto splitName = getStringRef().split("."); 47 return splitName.second.empty() ? splitName.first : splitName.second; 48 } 49 50 /// Return the name of this operation. This always succeeds. 51 StringRef OperationName::getStringRef() const { 52 return getIdentifier().strref(); 53 } 54 55 /// Return the name of this operation as an identifier. This always succeeds. 56 Identifier OperationName::getIdentifier() const { 57 if (auto *op = representation.dyn_cast<const AbstractOperation *>()) 58 return op->name; 59 return representation.get<Identifier>(); 60 } 61 62 OperationName OperationName::getFromOpaquePointer(const void *pointer) { 63 return OperationName( 64 RepresentationUnion::getFromOpaqueValue(const_cast<void *>(pointer))); 65 } 66 67 //===----------------------------------------------------------------------===// 68 // Operation 69 //===----------------------------------------------------------------------===// 70 71 /// Create a new Operation with the specific fields. 72 Operation *Operation::create(Location location, OperationName name, 73 TypeRange resultTypes, ValueRange operands, 74 ArrayRef<NamedAttribute> attributes, 75 BlockRange successors, unsigned numRegions) { 76 return create(location, name, resultTypes, operands, 77 DictionaryAttr::get(location.getContext(), attributes), 78 successors, numRegions); 79 } 80 81 /// Create a new Operation from operation state. 82 Operation *Operation::create(const OperationState &state) { 83 return create(state.location, state.name, state.types, state.operands, 84 state.attributes.getDictionary(state.getContext()), 85 state.successors, state.regions); 86 } 87 88 /// Create a new Operation with the specific fields. 89 Operation *Operation::create(Location location, OperationName name, 90 TypeRange resultTypes, ValueRange operands, 91 DictionaryAttr attributes, BlockRange successors, 92 RegionRange regions) { 93 unsigned numRegions = regions.size(); 94 Operation *op = create(location, name, resultTypes, operands, attributes, 95 successors, numRegions); 96 for (unsigned i = 0; i < numRegions; ++i) 97 if (regions[i]) 98 op->getRegion(i).takeBody(*regions[i]); 99 return op; 100 } 101 102 /// Overload of create that takes an existing DictionaryAttr to avoid 103 /// unnecessarily uniquing a list of attributes. 104 Operation *Operation::create(Location location, OperationName name, 105 TypeRange resultTypes, ValueRange operands, 106 DictionaryAttr attributes, BlockRange successors, 107 unsigned numRegions) { 108 assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && 109 "unexpected null result type"); 110 111 // We only need to allocate additional memory for a subset of results. 112 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 113 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 114 unsigned numSuccessors = successors.size(); 115 unsigned numOperands = operands.size(); 116 unsigned numResults = resultTypes.size(); 117 118 // If the operation is known to have no operands, don't allocate an operand 119 // storage. 120 bool needsOperandStorage = true; 121 if (operands.empty()) { 122 if (const AbstractOperation *abstractOp = name.getAbstractOperation()) 123 needsOperandStorage = !abstractOp->hasTrait<OpTrait::ZeroOperands>(); 124 } 125 126 // Compute the byte size for the operation and the operand storage. This takes 127 // into account the size of the operation, its trailing objects, and its 128 // prefixed objects. 129 size_t byteSize = 130 totalSizeToAlloc<BlockOperand, Region, detail::OperandStorage>( 131 numSuccessors, numRegions, needsOperandStorage ? 1 : 0) + 132 detail::OperandStorage::additionalAllocSize(numOperands); 133 size_t prefixByteSize = llvm::alignTo( 134 Operation::prefixAllocSize(numTrailingResults, numInlineResults), 135 alignof(Operation)); 136 char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); 137 void *rawMem = mallocMem + prefixByteSize; 138 139 // Create the new Operation. 140 Operation *op = 141 ::new (rawMem) Operation(location, name, numResults, numSuccessors, 142 numRegions, attributes, needsOperandStorage); 143 144 assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && 145 "unexpected successors in a non-terminator operation"); 146 147 // Initialize the results. 148 auto resultTypeIt = resultTypes.begin(); 149 for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) 150 new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); 151 for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { 152 new (op->getOutOfLineOpResult(i)) 153 detail::OutOfLineOpResult(*resultTypeIt, i); 154 } 155 156 // Initialize the regions. 157 for (unsigned i = 0; i != numRegions; ++i) 158 new (&op->getRegion(i)) Region(op); 159 160 // Initialize the operands. 161 if (needsOperandStorage) 162 new (&op->getOperandStorage()) detail::OperandStorage(op, operands); 163 164 // Initialize the successors. 165 auto blockOperands = op->getBlockOperands(); 166 for (unsigned i = 0; i != numSuccessors; ++i) 167 new (&blockOperands[i]) BlockOperand(op, successors[i]); 168 169 return op; 170 } 171 172 Operation::Operation(Location location, OperationName name, unsigned numResults, 173 unsigned numSuccessors, unsigned numRegions, 174 DictionaryAttr attributes, bool hasOperandStorage) 175 : location(location), numResults(numResults), numSuccs(numSuccessors), 176 numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), 177 attrs(attributes) { 178 assert(attributes && "unexpected null attribute dictionary"); 179 } 180 181 // Operations are deleted through the destroy() member because they are 182 // allocated via malloc. 183 Operation::~Operation() { 184 assert(block == nullptr && "operation destroyed but still in a block"); 185 186 // Explicitly run the destructors for the operands. 187 if (hasOperandStorage) 188 getOperandStorage().~OperandStorage(); 189 190 // Explicitly run the destructors for the successors. 191 for (auto &successor : getBlockOperands()) 192 successor.~BlockOperand(); 193 194 // Explicitly destroy the regions. 195 for (auto ®ion : getRegions()) 196 region.~Region(); 197 } 198 199 /// Destroy this operation or one of its subclasses. 200 void Operation::destroy() { 201 // Operations may have additional prefixed allocation, which needs to be 202 // accounted for here when computing the address to free. 203 char *rawMem = reinterpret_cast<char *>(this) - 204 llvm::alignTo(prefixAllocSize(), alignof(Operation)); 205 this->~Operation(); 206 free(rawMem); 207 } 208 209 /// Return the context this operation is associated with. 210 MLIRContext *Operation::getContext() { return location->getContext(); } 211 212 /// Return the dialect this operation is associated with, or nullptr if the 213 /// associated dialect is not registered. 214 Dialect *Operation::getDialect() { return getName().getDialect(); } 215 216 Region *Operation::getParentRegion() { 217 return block ? block->getParent() : nullptr; 218 } 219 220 Operation *Operation::getParentOp() { 221 return block ? block->getParentOp() : nullptr; 222 } 223 224 /// Return true if this operation is a proper ancestor of the `other` 225 /// operation. 226 bool Operation::isProperAncestor(Operation *other) { 227 while ((other = other->getParentOp())) 228 if (this == other) 229 return true; 230 return false; 231 } 232 233 /// Replace any uses of 'from' with 'to' within this operation. 234 void Operation::replaceUsesOfWith(Value from, Value to) { 235 if (from == to) 236 return; 237 for (auto &operand : getOpOperands()) 238 if (operand.get() == from) 239 operand.set(to); 240 } 241 242 /// Replace the current operands of this operation with the ones provided in 243 /// 'operands'. 244 void Operation::setOperands(ValueRange operands) { 245 if (LLVM_LIKELY(hasOperandStorage)) 246 return getOperandStorage().setOperands(this, operands); 247 assert(operands.empty() && "setting operands without an operand storage"); 248 } 249 250 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 251 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 252 /// than the range pointed to by 'start'+'length'. 253 void Operation::setOperands(unsigned start, unsigned length, 254 ValueRange operands) { 255 assert((start + length) <= getNumOperands() && 256 "invalid operand range specified"); 257 if (LLVM_LIKELY(hasOperandStorage)) 258 return getOperandStorage().setOperands(this, start, length, operands); 259 assert(operands.empty() && "setting operands without an operand storage"); 260 } 261 262 /// Insert the given operands into the operand list at the given 'index'. 263 void Operation::insertOperands(unsigned index, ValueRange operands) { 264 if (LLVM_LIKELY(hasOperandStorage)) 265 return setOperands(index, /*length=*/0, operands); 266 assert(operands.empty() && "inserting operands without an operand storage"); 267 } 268 269 //===----------------------------------------------------------------------===// 270 // Diagnostics 271 //===----------------------------------------------------------------------===// 272 273 /// Emit an error about fatal conditions with this operation, reporting up to 274 /// any diagnostic handlers that may be listening. 275 InFlightDiagnostic Operation::emitError(const Twine &message) { 276 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 277 if (getContext()->shouldPrintOpOnDiagnostic()) { 278 // Print out the operation explicitly here so that we can print the generic 279 // form. 280 // TODO: It would be nice if we could instead provide the 281 // specific printing flags when adding the operation as an argument to the 282 // diagnostic. 283 std::string printedOp; 284 { 285 llvm::raw_string_ostream os(printedOp); 286 print(os, OpPrintingFlags().printGenericOpForm().useLocalScope()); 287 } 288 diag.attachNote(getLoc()) << "see current operation: " << printedOp; 289 } 290 return diag; 291 } 292 293 /// Emit a warning about this operation, reporting up to any diagnostic 294 /// handlers that may be listening. 295 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 296 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 297 if (getContext()->shouldPrintOpOnDiagnostic()) 298 diag.attachNote(getLoc()) << "see current operation: " << *this; 299 return diag; 300 } 301 302 /// Emit a remark about this operation, reporting up to any diagnostic 303 /// handlers that may be listening. 304 InFlightDiagnostic Operation::emitRemark(const Twine &message) { 305 InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); 306 if (getContext()->shouldPrintOpOnDiagnostic()) 307 diag.attachNote(getLoc()) << "see current operation: " << *this; 308 return diag; 309 } 310 311 //===----------------------------------------------------------------------===// 312 // Operation Ordering 313 //===----------------------------------------------------------------------===// 314 315 constexpr unsigned Operation::kInvalidOrderIdx; 316 constexpr unsigned Operation::kOrderStride; 317 318 /// Given an operation 'other' that is within the same parent block, return 319 /// whether the current operation is before 'other' in the operation list 320 /// of the parent block. 321 /// Note: This function has an average complexity of O(1), but worst case may 322 /// take O(N) where N is the number of operations within the parent block. 323 bool Operation::isBeforeInBlock(Operation *other) { 324 assert(block && "Operations without parent blocks have no order."); 325 assert(other && other->block == block && 326 "Expected other operation to have the same parent block."); 327 // If the order of the block is already invalid, directly recompute the 328 // parent. 329 if (!block->isOpOrderValid()) { 330 block->recomputeOpOrder(); 331 } else { 332 // Update the order either operation if necessary. 333 updateOrderIfNecessary(); 334 other->updateOrderIfNecessary(); 335 } 336 337 return orderIndex < other->orderIndex; 338 } 339 340 /// Update the order index of this operation of this operation if necessary, 341 /// potentially recomputing the order of the parent block. 342 void Operation::updateOrderIfNecessary() { 343 assert(block && "expected valid parent"); 344 345 // If the order is valid for this operation there is nothing to do. 346 if (hasValidOrder()) 347 return; 348 Operation *blockFront = &block->front(); 349 Operation *blockBack = &block->back(); 350 351 // This method is expected to only be invoked on blocks with more than one 352 // operation. 353 assert(blockFront != blockBack && "expected more than one operation"); 354 355 // If the operation is at the end of the block. 356 if (this == blockBack) { 357 Operation *prevNode = getPrevNode(); 358 if (!prevNode->hasValidOrder()) 359 return block->recomputeOpOrder(); 360 361 // Add the stride to the previous operation. 362 orderIndex = prevNode->orderIndex + kOrderStride; 363 return; 364 } 365 366 // If this is the first operation try to use the next operation to compute the 367 // ordering. 368 if (this == blockFront) { 369 Operation *nextNode = getNextNode(); 370 if (!nextNode->hasValidOrder()) 371 return block->recomputeOpOrder(); 372 // There is no order to give this operation. 373 if (nextNode->orderIndex == 0) 374 return block->recomputeOpOrder(); 375 376 // If we can't use the stride, just take the middle value left. This is safe 377 // because we know there is at least one valid index to assign to. 378 if (nextNode->orderIndex <= kOrderStride) 379 orderIndex = (nextNode->orderIndex / 2); 380 else 381 orderIndex = kOrderStride; 382 return; 383 } 384 385 // Otherwise, this operation is between two others. Place this operation in 386 // the middle of the previous and next if possible. 387 Operation *prevNode = getPrevNode(), *nextNode = getNextNode(); 388 if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder()) 389 return block->recomputeOpOrder(); 390 unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex; 391 392 // Check to see if there is a valid order between the two. 393 if (prevOrder + 1 == nextOrder) 394 return block->recomputeOpOrder(); 395 orderIndex = prevOrder + ((nextOrder - prevOrder) / 2); 396 } 397 398 //===----------------------------------------------------------------------===// 399 // ilist_traits for Operation 400 //===----------------------------------------------------------------------===// 401 402 auto llvm::ilist_detail::SpecificNodeAccess< 403 typename llvm::ilist_detail::compute_node_options< 404 ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * { 405 return NodeAccess::getNodePtr<OptionsT>(N); 406 } 407 408 auto llvm::ilist_detail::SpecificNodeAccess< 409 typename llvm::ilist_detail::compute_node_options< 410 ::mlir::Operation>::type>::getNodePtr(const_pointer N) 411 -> const node_type * { 412 return NodeAccess::getNodePtr<OptionsT>(N); 413 } 414 415 auto llvm::ilist_detail::SpecificNodeAccess< 416 typename llvm::ilist_detail::compute_node_options< 417 ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer { 418 return NodeAccess::getValuePtr<OptionsT>(N); 419 } 420 421 auto llvm::ilist_detail::SpecificNodeAccess< 422 typename llvm::ilist_detail::compute_node_options< 423 ::mlir::Operation>::type>::getValuePtr(const node_type *N) 424 -> const_pointer { 425 return NodeAccess::getValuePtr<OptionsT>(N); 426 } 427 428 void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { 429 op->destroy(); 430 } 431 432 Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { 433 size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); 434 iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this)); 435 return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset); 436 } 437 438 /// This is a trait method invoked when an operation is added to a block. We 439 /// keep the block pointer up to date. 440 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 441 assert(!op->getBlock() && "already in an operation block!"); 442 op->block = getContainingBlock(); 443 444 // Invalidate the order on the operation. 445 op->orderIndex = Operation::kInvalidOrderIdx; 446 } 447 448 /// This is a trait method invoked when an operation is removed from a block. 449 /// We keep the block pointer up to date. 450 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 451 assert(op->block && "not already in an operation block!"); 452 op->block = nullptr; 453 } 454 455 /// This is a trait method invoked when an operation is moved from one block 456 /// to another. We keep the block pointer up to date. 457 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 458 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 459 Block *curParent = getContainingBlock(); 460 461 // Invalidate the ordering of the parent block. 462 curParent->invalidateOpOrder(); 463 464 // If we are transferring operations within the same block, the block 465 // pointer doesn't need to be updated. 466 if (curParent == otherList.getContainingBlock()) 467 return; 468 469 // Update the 'block' member of each operation. 470 for (; first != last; ++first) 471 first->block = curParent; 472 } 473 474 /// Remove this operation (and its descendants) from its Block and delete 475 /// all of them. 476 void Operation::erase() { 477 if (auto *parent = getBlock()) 478 parent->getOperations().erase(this); 479 else 480 destroy(); 481 } 482 483 /// Remove the operation from its parent block, but don't delete it. 484 void Operation::remove() { 485 if (Block *parent = getBlock()) 486 parent->getOperations().remove(this); 487 } 488 489 /// Unlink this operation from its current block and insert it right before 490 /// `existingOp` which may be in the same or another block in the same 491 /// function. 492 void Operation::moveBefore(Operation *existingOp) { 493 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 494 } 495 496 /// Unlink this operation from its current basic block and insert it right 497 /// before `iterator` in the specified basic block. 498 void Operation::moveBefore(Block *block, 499 llvm::iplist<Operation>::iterator iterator) { 500 block->getOperations().splice(iterator, getBlock()->getOperations(), 501 getIterator()); 502 } 503 504 /// Unlink this operation from its current block and insert it right after 505 /// `existingOp` which may be in the same or another block in the same function. 506 void Operation::moveAfter(Operation *existingOp) { 507 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 508 } 509 510 /// Unlink this operation from its current block and insert it right after 511 /// `iterator` in the specified block. 512 void Operation::moveAfter(Block *block, 513 llvm::iplist<Operation>::iterator iterator) { 514 assert(iterator != block->end() && "cannot move after end of block"); 515 moveBefore(&*std::next(iterator)); 516 } 517 518 /// This drops all operand uses from this operation, which is an essential 519 /// step in breaking cyclic dependences between references when they are to 520 /// be deleted. 521 void Operation::dropAllReferences() { 522 for (auto &op : getOpOperands()) 523 op.drop(); 524 525 for (auto ®ion : getRegions()) 526 region.dropAllReferences(); 527 528 for (auto &dest : getBlockOperands()) 529 dest.drop(); 530 } 531 532 /// This drops all uses of any values defined by this operation or its nested 533 /// regions, wherever they are located. 534 void Operation::dropAllDefinedValueUses() { 535 dropAllUses(); 536 537 for (auto ®ion : getRegions()) 538 for (auto &block : region) 539 block.dropAllDefinedValueUses(); 540 } 541 542 void Operation::setSuccessor(Block *block, unsigned index) { 543 assert(index < getNumSuccessors()); 544 getBlockOperands()[index].set(block); 545 } 546 547 /// Attempt to fold this operation using the Op's registered foldHook. 548 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 549 SmallVectorImpl<OpFoldResult> &results) { 550 // If we have a registered operation definition matching this one, use it to 551 // try to constant fold the operation. 552 auto *abstractOp = getAbstractOperation(); 553 if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results))) 554 return success(); 555 556 // Otherwise, fall back on the dialect hook to handle it. 557 Dialect *dialect = getDialect(); 558 if (!dialect) 559 return failure(); 560 561 auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>(); 562 if (!interface) 563 return failure(); 564 565 return interface->fold(this, operands, results); 566 } 567 568 /// Emit an error with the op name prefixed, like "'dim' op " which is 569 /// convenient for verifiers. 570 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 571 return emitError() << "'" << getName() << "' op " << message; 572 } 573 574 //===----------------------------------------------------------------------===// 575 // Operation Cloning 576 //===----------------------------------------------------------------------===// 577 578 /// Create a deep copy of this operation but keep the operation regions empty. 579 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 580 /// to contain the results. 581 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 582 SmallVector<Value, 8> operands; 583 SmallVector<Block *, 2> successors; 584 585 // Remap the operands. 586 operands.reserve(getNumOperands()); 587 for (auto opValue : getOperands()) 588 operands.push_back(mapper.lookupOrDefault(opValue)); 589 590 // Remap the successors. 591 successors.reserve(getNumSuccessors()); 592 for (Block *successor : getSuccessors()) 593 successors.push_back(mapper.lookupOrDefault(successor)); 594 595 // Create the new operation. 596 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 597 successors, getNumRegions()); 598 599 // Remember the mapping of any results. 600 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 601 mapper.map(getResult(i), newOp->getResult(i)); 602 603 return newOp; 604 } 605 606 Operation *Operation::cloneWithoutRegions() { 607 BlockAndValueMapping mapper; 608 return cloneWithoutRegions(mapper); 609 } 610 611 /// Create a deep copy of this operation, remapping any operands that use 612 /// values outside of the operation using the map that is provided (leaving 613 /// them alone if no entry is present). Replaces references to cloned 614 /// sub-operations to the corresponding operation that is copied, and adds 615 /// those mappings to the map. 616 Operation *Operation::clone(BlockAndValueMapping &mapper) { 617 auto *newOp = cloneWithoutRegions(mapper); 618 619 // Clone the regions. 620 for (unsigned i = 0; i != numRegions; ++i) 621 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 622 623 return newOp; 624 } 625 626 Operation *Operation::clone() { 627 BlockAndValueMapping mapper; 628 return clone(mapper); 629 } 630 631 //===----------------------------------------------------------------------===// 632 // OpState trait class. 633 //===----------------------------------------------------------------------===// 634 635 // The fallback for the parser is to reject the custom assembly form. 636 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 637 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 638 } 639 640 // The fallback for the printer is to print in the generic assembly form. 641 void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); } 642 643 /// Emit an error about fatal conditions with this operation, reporting up to 644 /// any diagnostic handlers that may be listening. 645 InFlightDiagnostic OpState::emitError(const Twine &message) { 646 return getOperation()->emitError(message); 647 } 648 649 /// Emit an error with the op name prefixed, like "'dim' op " which is 650 /// convenient for verifiers. 651 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 652 return getOperation()->emitOpError(message); 653 } 654 655 /// Emit a warning about this operation, reporting up to any diagnostic 656 /// handlers that may be listening. 657 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 658 return getOperation()->emitWarning(message); 659 } 660 661 /// Emit a remark about this operation, reporting up to any diagnostic 662 /// handlers that may be listening. 663 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 664 return getOperation()->emitRemark(message); 665 } 666 667 //===----------------------------------------------------------------------===// 668 // Op Trait implementations 669 //===----------------------------------------------------------------------===// 670 671 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 672 auto *argumentOp = op->getOperand(0).getDefiningOp(); 673 if (argumentOp && op->getName() == argumentOp->getName()) { 674 // Replace the outer operation output with the inner operation. 675 return op->getOperand(0); 676 } 677 678 return {}; 679 } 680 681 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 682 auto *argumentOp = op->getOperand(0).getDefiningOp(); 683 if (argumentOp && op->getName() == argumentOp->getName()) { 684 // Replace the outer involutions output with inner's input. 685 return argumentOp->getOperand(0); 686 } 687 688 return {}; 689 } 690 691 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 692 if (op->getNumOperands() != 0) 693 return op->emitOpError() << "requires zero operands"; 694 return success(); 695 } 696 697 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 698 if (op->getNumOperands() != 1) 699 return op->emitOpError() << "requires a single operand"; 700 return success(); 701 } 702 703 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 704 unsigned numOperands) { 705 if (op->getNumOperands() != numOperands) { 706 return op->emitOpError() << "expected " << numOperands 707 << " operands, but found " << op->getNumOperands(); 708 } 709 return success(); 710 } 711 712 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 713 unsigned numOperands) { 714 if (op->getNumOperands() < numOperands) 715 return op->emitOpError() 716 << "expected " << numOperands << " or more operands"; 717 return success(); 718 } 719 720 /// If this is a vector type, or a tensor type, return the scalar element type 721 /// that it is built around, otherwise return the type unmodified. 722 static Type getTensorOrVectorElementType(Type type) { 723 if (auto vec = type.dyn_cast<VectorType>()) 724 return vec.getElementType(); 725 726 // Look through tensor<vector<...>> to find the underlying element type. 727 if (auto tensor = type.dyn_cast<TensorType>()) 728 return getTensorOrVectorElementType(tensor.getElementType()); 729 return type; 730 } 731 732 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 733 // FIXME: Add back check for no side effects on operation. 734 // Currently adding it would cause the shared library build 735 // to fail since there would be a dependency of IR on SideEffectInterfaces 736 // which is cyclical. 737 return success(); 738 } 739 740 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 741 // FIXME: Add back check for no side effects on operation. 742 // Currently adding it would cause the shared library build 743 // to fail since there would be a dependency of IR on SideEffectInterfaces 744 // which is cyclical. 745 return success(); 746 } 747 748 LogicalResult 749 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 750 for (auto opType : op->getOperandTypes()) { 751 auto type = getTensorOrVectorElementType(opType); 752 if (!type.isSignlessIntOrIndex()) 753 return op->emitOpError() << "requires an integer or index type"; 754 } 755 return success(); 756 } 757 758 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 759 for (auto opType : op->getOperandTypes()) { 760 auto type = getTensorOrVectorElementType(opType); 761 if (!type.isa<FloatType>()) 762 return op->emitOpError("requires a float type"); 763 } 764 return success(); 765 } 766 767 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 768 // Zero or one operand always have the "same" type. 769 unsigned nOperands = op->getNumOperands(); 770 if (nOperands < 2) 771 return success(); 772 773 auto type = op->getOperand(0).getType(); 774 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 775 if (opType != type) 776 return op->emitOpError() << "requires all operands to have the same type"; 777 return success(); 778 } 779 780 LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) { 781 if (op->getNumRegions() != 0) 782 return op->emitOpError() << "requires zero regions"; 783 return success(); 784 } 785 786 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 787 if (op->getNumRegions() != 1) 788 return op->emitOpError() << "requires one region"; 789 return success(); 790 } 791 792 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 793 unsigned numRegions) { 794 if (op->getNumRegions() != numRegions) 795 return op->emitOpError() << "expected " << numRegions << " regions"; 796 return success(); 797 } 798 799 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 800 unsigned numRegions) { 801 if (op->getNumRegions() < numRegions) 802 return op->emitOpError() << "expected " << numRegions << " or more regions"; 803 return success(); 804 } 805 806 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { 807 if (op->getNumResults() != 0) 808 return op->emitOpError() << "requires zero results"; 809 return success(); 810 } 811 812 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 813 if (op->getNumResults() != 1) 814 return op->emitOpError() << "requires one result"; 815 return success(); 816 } 817 818 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 819 unsigned numOperands) { 820 if (op->getNumResults() != numOperands) 821 return op->emitOpError() << "expected " << numOperands << " results"; 822 return success(); 823 } 824 825 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 826 unsigned numOperands) { 827 if (op->getNumResults() < numOperands) 828 return op->emitOpError() 829 << "expected " << numOperands << " or more results"; 830 return success(); 831 } 832 833 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 834 if (failed(verifyAtLeastNOperands(op, 1))) 835 return failure(); 836 837 auto type = op->getOperand(0).getType(); 838 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) { 839 if (failed(verifyCompatibleShape(opType, type))) 840 return op->emitOpError() << "requires the same shape for all operands"; 841 } 842 return success(); 843 } 844 845 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 846 if (failed(verifyAtLeastNOperands(op, 1)) || 847 failed(verifyAtLeastNResults(op, 1))) 848 return failure(); 849 850 auto type = op->getOperand(0).getType(); 851 for (auto resultType : op->getResultTypes()) { 852 if (failed(verifyCompatibleShape(resultType, type))) 853 return op->emitOpError() 854 << "requires the same shape for all operands and results"; 855 } 856 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) { 857 if (failed(verifyCompatibleShape(opType, type))) 858 return op->emitOpError() 859 << "requires the same shape for all operands and results"; 860 } 861 return success(); 862 } 863 864 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 865 if (failed(verifyAtLeastNOperands(op, 1))) 866 return failure(); 867 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 868 869 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 870 if (getElementTypeOrSelf(operand) != elementType) 871 return op->emitOpError("requires the same element type for all operands"); 872 } 873 874 return success(); 875 } 876 877 LogicalResult 878 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 879 if (failed(verifyAtLeastNOperands(op, 1)) || 880 failed(verifyAtLeastNResults(op, 1))) 881 return failure(); 882 883 auto elementType = getElementTypeOrSelf(op->getResult(0)); 884 885 // Verify result element type matches first result's element type. 886 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 887 if (getElementTypeOrSelf(result) != elementType) 888 return op->emitOpError( 889 "requires the same element type for all operands and results"); 890 } 891 892 // Verify operand's element type matches first result's element type. 893 for (auto operand : op->getOperands()) { 894 if (getElementTypeOrSelf(operand) != elementType) 895 return op->emitOpError( 896 "requires the same element type for all operands and results"); 897 } 898 899 return success(); 900 } 901 902 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 903 if (failed(verifyAtLeastNOperands(op, 1)) || 904 failed(verifyAtLeastNResults(op, 1))) 905 return failure(); 906 907 auto type = op->getResult(0).getType(); 908 auto elementType = getElementTypeOrSelf(type); 909 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 910 if (getElementTypeOrSelf(resultType) != elementType || 911 failed(verifyCompatibleShape(resultType, type))) 912 return op->emitOpError() 913 << "requires the same type for all operands and results"; 914 } 915 for (auto opType : op->getOperandTypes()) { 916 if (getElementTypeOrSelf(opType) != elementType || 917 failed(verifyCompatibleShape(opType, type))) 918 return op->emitOpError() 919 << "requires the same type for all operands and results"; 920 } 921 return success(); 922 } 923 924 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 925 Block *block = op->getBlock(); 926 // Verify that the operation is at the end of the respective parent block. 927 if (!block || &block->back() != op) 928 return op->emitOpError("must be the last operation in the parent block"); 929 return success(); 930 } 931 932 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 933 auto *parent = op->getParentRegion(); 934 935 // Verify that the operands lines up with the BB arguments in the successor. 936 for (Block *succ : op->getSuccessors()) 937 if (succ->getParent() != parent) 938 return op->emitError("reference to block defined in another region"); 939 return success(); 940 } 941 942 LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { 943 if (op->getNumSuccessors() != 0) { 944 return op->emitOpError("requires 0 successors but found ") 945 << op->getNumSuccessors(); 946 } 947 return success(); 948 } 949 950 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 951 if (op->getNumSuccessors() != 1) { 952 return op->emitOpError("requires 1 successor but found ") 953 << op->getNumSuccessors(); 954 } 955 return verifyTerminatorSuccessors(op); 956 } 957 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 958 unsigned numSuccessors) { 959 if (op->getNumSuccessors() != numSuccessors) { 960 return op->emitOpError("requires ") 961 << numSuccessors << " successors but found " 962 << op->getNumSuccessors(); 963 } 964 return verifyTerminatorSuccessors(op); 965 } 966 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 967 unsigned numSuccessors) { 968 if (op->getNumSuccessors() < numSuccessors) { 969 return op->emitOpError("requires at least ") 970 << numSuccessors << " successors but found " 971 << op->getNumSuccessors(); 972 } 973 return verifyTerminatorSuccessors(op); 974 } 975 976 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 977 for (auto resultType : op->getResultTypes()) { 978 auto elementType = getTensorOrVectorElementType(resultType); 979 bool isBoolType = elementType.isInteger(1); 980 if (!isBoolType) 981 return op->emitOpError() << "requires a bool result type"; 982 } 983 984 return success(); 985 } 986 987 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 988 for (auto resultType : op->getResultTypes()) 989 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 990 return op->emitOpError() << "requires a floating point type"; 991 992 return success(); 993 } 994 995 LogicalResult 996 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 997 for (auto resultType : op->getResultTypes()) 998 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 999 return op->emitOpError() << "requires an integer or index type"; 1000 return success(); 1001 } 1002 1003 static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName, 1004 bool isOperand) { 1005 auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); 1006 if (!sizeAttr) 1007 return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; 1008 1009 auto sizeAttrType = sizeAttr.getType().dyn_cast<VectorType>(); 1010 if (!sizeAttrType || sizeAttrType.getRank() != 1) 1011 return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; 1012 1013 if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) { 1014 return !element.isNonNegative(); 1015 })) 1016 return op->emitOpError("'") 1017 << attrName << "' attribute cannot have negative elements"; 1018 1019 size_t totalCount = std::accumulate( 1020 sizeAttr.begin(), sizeAttr.end(), 0, 1021 [](unsigned all, APInt one) { return all + one.getZExtValue(); }); 1022 1023 if (isOperand && totalCount != op->getNumOperands()) 1024 return op->emitOpError("operand count (") 1025 << op->getNumOperands() << ") does not match with the total size (" 1026 << totalCount << ") specified in attribute '" << attrName << "'"; 1027 else if (!isOperand && totalCount != op->getNumResults()) 1028 return op->emitOpError("result count (") 1029 << op->getNumResults() << ") does not match with the total size (" 1030 << totalCount << ") specified in attribute '" << attrName << "'"; 1031 return success(); 1032 } 1033 1034 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1035 StringRef attrName) { 1036 return verifyValueSizeAttr(op, attrName, /*isOperand=*/true); 1037 } 1038 1039 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1040 StringRef attrName) { 1041 return verifyValueSizeAttr(op, attrName, /*isOperand=*/false); 1042 } 1043 1044 LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1045 for (Region ®ion : op->getRegions()) { 1046 if (region.empty()) 1047 continue; 1048 1049 if (region.getNumArguments() != 0) { 1050 if (op->getNumRegions() > 1) 1051 return op->emitOpError("region #") 1052 << region.getRegionNumber() << " should have no arguments"; 1053 else 1054 return op->emitOpError("region should have no arguments"); 1055 } 1056 } 1057 return success(); 1058 } 1059 1060 /// Checks if two ShapedTypes are the same, ignoring the element type. 1061 static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) { 1062 if (a.getTypeID() != b.getTypeID()) 1063 return false; 1064 if (!a.hasRank()) 1065 return !b.hasRank(); 1066 return a.getShape() == b.getShape(); 1067 } 1068 1069 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1070 auto isMappableType = [](Type type) { 1071 return type.isa<VectorType, TensorType>(); 1072 }; 1073 auto resultMappableTypes = llvm::to_vector<1>( 1074 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1075 auto operandMappableTypes = llvm::to_vector<2>( 1076 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1077 1078 // If the op only has scalar operand/result types, then we have nothing to 1079 // check. 1080 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1081 return success(); 1082 1083 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1084 return op->emitOpError("if a result is non-scalar, then at least one " 1085 "operand must be non-scalar"); 1086 1087 assert(!operandMappableTypes.empty()); 1088 1089 if (resultMappableTypes.empty()) 1090 return op->emitOpError("if an operand is non-scalar, then there must be at " 1091 "least one non-scalar result"); 1092 1093 if (resultMappableTypes.size() != op->getNumResults()) 1094 return op->emitOpError( 1095 "if an operand is non-scalar, then all results must be non-scalar"); 1096 1097 auto mustMatchType = operandMappableTypes[0].cast<ShapedType>(); 1098 for (auto type : 1099 llvm::concat<Type>(resultMappableTypes, operandMappableTypes)) { 1100 if (!areSameShapedTypeIgnoringElementType(type.cast<ShapedType>(), 1101 mustMatchType)) { 1102 return op->emitOpError() << "all non-scalar operands/results must have " 1103 "the same shape and base type: found " 1104 << type << " and " << mustMatchType; 1105 } 1106 } 1107 1108 return success(); 1109 } 1110 1111 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1112 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1113 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1114 } 1115 1116 //===----------------------------------------------------------------------===// 1117 // BinaryOp implementation 1118 //===----------------------------------------------------------------------===// 1119 1120 // These functions are out-of-line implementations of the methods in BinaryOp, 1121 // which avoids them being template instantiated/duplicated. 1122 1123 void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, 1124 Value rhs) { 1125 assert(lhs.getType() == rhs.getType()); 1126 result.addOperands({lhs, rhs}); 1127 result.types.push_back(lhs.getType()); 1128 } 1129 1130 ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser, 1131 OperationState &result) { 1132 SmallVector<OpAsmParser::OperandType, 2> ops; 1133 Type type; 1134 return failure(parser.parseOperandList(ops) || 1135 parser.parseOptionalAttrDict(result.attributes) || 1136 parser.parseColonType(type) || 1137 parser.resolveOperands(ops, type, result.operands) || 1138 parser.addTypeToList(type, result.types)); 1139 } 1140 1141 void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { 1142 assert(op->getNumResults() == 1 && "op should have one result"); 1143 1144 // If not all the operand and result types are the same, just use the 1145 // generic assembly form to avoid omitting information in printing. 1146 auto resultType = op->getResult(0).getType(); 1147 if (llvm::any_of(op->getOperandTypes(), 1148 [&](Type type) { return type != resultType; })) { 1149 p.printGenericOp(op); 1150 return; 1151 } 1152 1153 p << op->getName() << ' '; 1154 p.printOperands(op->getOperands()); 1155 p.printOptionalAttrDict(op->getAttrs()); 1156 // Now we can output only one type for all operands and the result. 1157 p << " : " << resultType; 1158 } 1159 1160 //===----------------------------------------------------------------------===// 1161 // CastOp implementation 1162 //===----------------------------------------------------------------------===// 1163 1164 /// Attempt to fold the given cast operation. 1165 LogicalResult 1166 impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, 1167 SmallVectorImpl<OpFoldResult> &foldResults) { 1168 OperandRange operands = op->getOperands(); 1169 if (operands.empty()) 1170 return failure(); 1171 ResultRange results = op->getResults(); 1172 1173 // Check for the case where the input and output types match 1-1. 1174 if (operands.getTypes() == results.getTypes()) { 1175 foldResults.append(operands.begin(), operands.end()); 1176 return success(); 1177 } 1178 1179 return failure(); 1180 } 1181 1182 /// Attempt to verify the given cast operation. 1183 LogicalResult impl::verifyCastInterfaceOp( 1184 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { 1185 auto resultTypes = op->getResultTypes(); 1186 if (llvm::empty(resultTypes)) 1187 return op->emitOpError() 1188 << "expected at least one result for cast operation"; 1189 1190 auto operandTypes = op->getOperandTypes(); 1191 if (!areCastCompatible(operandTypes, resultTypes)) { 1192 InFlightDiagnostic diag = op->emitOpError("operand type"); 1193 if (llvm::empty(operandTypes)) 1194 diag << "s []"; 1195 else if (llvm::size(operandTypes) == 1) 1196 diag << " " << *operandTypes.begin(); 1197 else 1198 diag << "s " << operandTypes; 1199 return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") 1200 << resultTypes << " are cast incompatible"; 1201 } 1202 1203 return success(); 1204 } 1205 1206 void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source, 1207 Type destType) { 1208 result.addOperands(source); 1209 result.addTypes(destType); 1210 } 1211 1212 ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) { 1213 OpAsmParser::OperandType srcInfo; 1214 Type srcType, dstType; 1215 return failure(parser.parseOperand(srcInfo) || 1216 parser.parseOptionalAttrDict(result.attributes) || 1217 parser.parseColonType(srcType) || 1218 parser.resolveOperand(srcInfo, srcType, result.operands) || 1219 parser.parseKeywordType("to", dstType) || 1220 parser.addTypeToList(dstType, result.types)); 1221 } 1222 1223 void impl::printCastOp(Operation *op, OpAsmPrinter &p) { 1224 p << op->getName() << ' ' << op->getOperand(0); 1225 p.printOptionalAttrDict(op->getAttrs()); 1226 p << " : " << op->getOperand(0).getType() << " to " 1227 << op->getResult(0).getType(); 1228 } 1229 1230 Value impl::foldCastOp(Operation *op) { 1231 // Identity cast 1232 if (op->getOperand(0).getType() == op->getResult(0).getType()) 1233 return op->getOperand(0); 1234 return nullptr; 1235 } 1236 1237 LogicalResult 1238 impl::verifyCastOp(Operation *op, 1239 function_ref<bool(Type, Type)> areCastCompatible) { 1240 auto opType = op->getOperand(0).getType(); 1241 auto resType = op->getResult(0).getType(); 1242 if (!areCastCompatible(opType, resType)) 1243 return op->emitError("operand type ") 1244 << opType << " and result type " << resType 1245 << " are cast incompatible"; 1246 1247 return success(); 1248 } 1249 1250 //===----------------------------------------------------------------------===// 1251 // Misc. utils 1252 //===----------------------------------------------------------------------===// 1253 1254 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1255 /// region's only block if it does not have a terminator already. If the region 1256 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1257 /// terminator operation to insert. 1258 void impl::ensureRegionTerminator( 1259 Region ®ion, OpBuilder &builder, Location loc, 1260 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1261 OpBuilder::InsertionGuard guard(builder); 1262 if (region.empty()) 1263 builder.createBlock(®ion); 1264 1265 Block &block = region.back(); 1266 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1267 return; 1268 1269 builder.setInsertionPointToEnd(&block); 1270 builder.insert(buildTerminatorOp(builder, loc)); 1271 } 1272 1273 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1274 /// function. 1275 void impl::ensureRegionTerminator( 1276 Region ®ion, Builder &builder, Location loc, 1277 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1278 OpBuilder opBuilder(builder.getContext()); 1279 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1280 } 1281 1282 //===----------------------------------------------------------------------===// 1283 // UseIterator 1284 //===----------------------------------------------------------------------===// 1285 1286 Operation::UseIterator::UseIterator(Operation *op, bool end) 1287 : op(op), res(end ? op->result_end() : op->result_begin()) { 1288 // Only initialize current use if there are results/can be uses. 1289 if (op->getNumResults()) 1290 skipOverResultsWithNoUsers(); 1291 } 1292 1293 Operation::UseIterator &Operation::UseIterator::operator++() { 1294 // We increment over uses, if we reach the last use then move to next 1295 // result. 1296 if (use != (*res).use_end()) 1297 ++use; 1298 if (use == (*res).use_end()) { 1299 ++res; 1300 skipOverResultsWithNoUsers(); 1301 } 1302 return *this; 1303 } 1304 1305 void Operation::UseIterator::skipOverResultsWithNoUsers() { 1306 while (res != op->result_end() && (*res).use_empty()) 1307 ++res; 1308 1309 // If we are at the last result, then set use to first use of 1310 // first result (sentinel value used for end). 1311 if (res == op->result_end()) 1312 use = {}; 1313 else 1314 use = (*res).use_begin(); 1315 } 1316