1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Interfaces/FoldInterfaces.h" 17 #include <numeric> 18 19 using namespace mlir; 20 21 OpAsmParser::~OpAsmParser() {} 22 23 //===----------------------------------------------------------------------===// 24 // OperationName 25 //===----------------------------------------------------------------------===// 26 27 /// Form the OperationName for an op with the specified string. This either is 28 /// a reference to an AbstractOperation if one is known, or a uniqued Identifier 29 /// if not. 30 OperationName::OperationName(StringRef name, MLIRContext *context) { 31 if (auto *op = AbstractOperation::lookup(name, context)) 32 representation = op; 33 else 34 representation = Identifier::get(name, context); 35 } 36 37 /// Return the name of the dialect this operation is registered to. 38 StringRef OperationName::getDialectNamespace() const { 39 if (Dialect *dialect = getDialect()) 40 return dialect->getNamespace(); 41 return representation.get<Identifier>().strref().split('.').first; 42 } 43 44 /// Return the operation name with dialect name stripped, if it has one. 45 StringRef OperationName::stripDialect() const { 46 auto splitName = getStringRef().split("."); 47 return splitName.second.empty() ? splitName.first : splitName.second; 48 } 49 50 /// Return the name of this operation. This always succeeds. 
51 StringRef OperationName::getStringRef() const { 52 return getIdentifier().strref(); 53 } 54 55 /// Return the name of this operation as an identifier. This always succeeds. 56 Identifier OperationName::getIdentifier() const { 57 if (auto *op = representation.dyn_cast<const AbstractOperation *>()) 58 return op->name; 59 return representation.get<Identifier>(); 60 } 61 62 OperationName OperationName::getFromOpaquePointer(const void *pointer) { 63 return OperationName( 64 RepresentationUnion::getFromOpaqueValue(const_cast<void *>(pointer))); 65 } 66 67 //===----------------------------------------------------------------------===// 68 // Operation 69 //===----------------------------------------------------------------------===// 70 71 /// Create a new Operation with the specific fields. 72 Operation *Operation::create(Location location, OperationName name, 73 TypeRange resultTypes, ValueRange operands, 74 ArrayRef<NamedAttribute> attributes, 75 BlockRange successors, unsigned numRegions) { 76 return create(location, name, resultTypes, operands, 77 DictionaryAttr::get(location.getContext(), attributes), 78 successors, numRegions); 79 } 80 81 /// Create a new Operation from operation state. 82 Operation *Operation::create(const OperationState &state) { 83 return create(state.location, state.name, state.types, state.operands, 84 state.attributes.getDictionary(state.getContext()), 85 state.successors, state.regions); 86 } 87 88 /// Create a new Operation with the specific fields. 
89 Operation *Operation::create(Location location, OperationName name, 90 TypeRange resultTypes, ValueRange operands, 91 DictionaryAttr attributes, BlockRange successors, 92 RegionRange regions) { 93 unsigned numRegions = regions.size(); 94 Operation *op = create(location, name, resultTypes, operands, attributes, 95 successors, numRegions); 96 for (unsigned i = 0; i < numRegions; ++i) 97 if (regions[i]) 98 op->getRegion(i).takeBody(*regions[i]); 99 return op; 100 } 101 102 /// Overload of create that takes an existing DictionaryAttr to avoid 103 /// unnecessarily uniquing a list of attributes. 104 Operation *Operation::create(Location location, OperationName name, 105 TypeRange resultTypes, ValueRange operands, 106 DictionaryAttr attributes, BlockRange successors, 107 unsigned numRegions) { 108 assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && 109 "unexpected null result type"); 110 111 // We only need to allocate additional memory for a subset of results. 112 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 113 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 114 unsigned numSuccessors = successors.size(); 115 unsigned numOperands = operands.size(); 116 unsigned numResults = resultTypes.size(); 117 118 // If the operation is known to have no operands, don't allocate an operand 119 // storage. 120 bool needsOperandStorage = true; 121 if (operands.empty()) { 122 if (const AbstractOperation *abstractOp = name.getAbstractOperation()) 123 needsOperandStorage = !abstractOp->hasTrait<OpTrait::ZeroOperands>(); 124 } 125 126 // Compute the byte size for the operation and the operand storage. This takes 127 // into account the size of the operation, its trailing objects, and its 128 // prefixed objects. 129 size_t byteSize = 130 totalSizeToAlloc<BlockOperand, Region, detail::OperandStorage>( 131 numSuccessors, numRegions, needsOperandStorage ? 
1 : 0) + 132 detail::OperandStorage::additionalAllocSize(numOperands); 133 size_t prefixByteSize = llvm::alignTo( 134 Operation::prefixAllocSize(numTrailingResults, numInlineResults), 135 alignof(Operation)); 136 char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); 137 void *rawMem = mallocMem + prefixByteSize; 138 139 // Create the new Operation. 140 Operation *op = 141 ::new (rawMem) Operation(location, name, numResults, numSuccessors, 142 numRegions, attributes, needsOperandStorage); 143 144 assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && 145 "unexpected successors in a non-terminator operation"); 146 147 // Initialize the results. 148 auto resultTypeIt = resultTypes.begin(); 149 for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) 150 new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); 151 for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { 152 new (op->getOutOfLineOpResult(i)) 153 detail::OutOfLineOpResult(*resultTypeIt, i); 154 } 155 156 // Initialize the regions. 157 for (unsigned i = 0; i != numRegions; ++i) 158 new (&op->getRegion(i)) Region(op); 159 160 // Initialize the operands. 161 if (needsOperandStorage) 162 new (&op->getOperandStorage()) detail::OperandStorage(op, operands); 163 164 // Initialize the successors. 
165 auto blockOperands = op->getBlockOperands(); 166 for (unsigned i = 0; i != numSuccessors; ++i) 167 new (&blockOperands[i]) BlockOperand(op, successors[i]); 168 169 return op; 170 } 171 172 Operation::Operation(Location location, OperationName name, unsigned numResults, 173 unsigned numSuccessors, unsigned numRegions, 174 DictionaryAttr attributes, bool hasOperandStorage) 175 : location(location), numResults(numResults), numSuccs(numSuccessors), 176 numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), 177 attrs(attributes) { 178 assert(attributes && "unexpected null attribute dictionary"); 179 } 180 181 // Operations are deleted through the destroy() member because they are 182 // allocated via malloc. 183 Operation::~Operation() { 184 assert(block == nullptr && "operation destroyed but still in a block"); 185 #ifndef NDEBUG 186 if (!use_empty()) { 187 { 188 InFlightDiagnostic diag = 189 emitOpError("operation destroyed but still has uses"); 190 for (Operation *user : getUsers()) 191 diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; 192 } 193 llvm::report_fatal_error("operation destroyed but still has uses"); 194 } 195 #endif 196 // Explicitly run the destructors for the operands. 197 if (hasOperandStorage) 198 getOperandStorage().~OperandStorage(); 199 200 // Explicitly run the destructors for the successors. 201 for (auto &successor : getBlockOperands()) 202 successor.~BlockOperand(); 203 204 // Explicitly destroy the regions. 205 for (auto ®ion : getRegions()) 206 region.~Region(); 207 } 208 209 /// Destroy this operation or one of its subclasses. 210 void Operation::destroy() { 211 // Operations may have additional prefixed allocation, which needs to be 212 // accounted for here when computing the address to free. 
213 char *rawMem = reinterpret_cast<char *>(this) - 214 llvm::alignTo(prefixAllocSize(), alignof(Operation)); 215 this->~Operation(); 216 free(rawMem); 217 } 218 219 /// Return true if this operation is a proper ancestor of the `other` 220 /// operation. 221 bool Operation::isProperAncestor(Operation *other) { 222 while ((other = other->getParentOp())) 223 if (this == other) 224 return true; 225 return false; 226 } 227 228 /// Replace any uses of 'from' with 'to' within this operation. 229 void Operation::replaceUsesOfWith(Value from, Value to) { 230 if (from == to) 231 return; 232 for (auto &operand : getOpOperands()) 233 if (operand.get() == from) 234 operand.set(to); 235 } 236 237 /// Replace the current operands of this operation with the ones provided in 238 /// 'operands'. 239 void Operation::setOperands(ValueRange operands) { 240 if (LLVM_LIKELY(hasOperandStorage)) 241 return getOperandStorage().setOperands(this, operands); 242 assert(operands.empty() && "setting operands without an operand storage"); 243 } 244 245 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 246 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 247 /// than the range pointed to by 'start'+'length'. 248 void Operation::setOperands(unsigned start, unsigned length, 249 ValueRange operands) { 250 assert((start + length) <= getNumOperands() && 251 "invalid operand range specified"); 252 if (LLVM_LIKELY(hasOperandStorage)) 253 return getOperandStorage().setOperands(this, start, length, operands); 254 assert(operands.empty() && "setting operands without an operand storage"); 255 } 256 257 /// Insert the given operands into the operand list at the given 'index'. 
258 void Operation::insertOperands(unsigned index, ValueRange operands) { 259 if (LLVM_LIKELY(hasOperandStorage)) 260 return setOperands(index, /*length=*/0, operands); 261 assert(operands.empty() && "inserting operands without an operand storage"); 262 } 263 264 //===----------------------------------------------------------------------===// 265 // Diagnostics 266 //===----------------------------------------------------------------------===// 267 268 /// Emit an error about fatal conditions with this operation, reporting up to 269 /// any diagnostic handlers that may be listening. 270 InFlightDiagnostic Operation::emitError(const Twine &message) { 271 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 272 if (getContext()->shouldPrintOpOnDiagnostic()) { 273 // Print out the operation explicitly here so that we can print the generic 274 // form. 275 // TODO: It would be nice if we could instead provide the 276 // specific printing flags when adding the operation as an argument to the 277 // diagnostic. 278 std::string printedOp; 279 { 280 llvm::raw_string_ostream os(printedOp); 281 print(os, OpPrintingFlags().printGenericOpForm().useLocalScope()); 282 } 283 diag.attachNote(getLoc()) << "see current operation: " << printedOp; 284 } 285 return diag; 286 } 287 288 /// Emit a warning about this operation, reporting up to any diagnostic 289 /// handlers that may be listening. 290 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 291 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 292 if (getContext()->shouldPrintOpOnDiagnostic()) 293 diag.attachNote(getLoc()) << "see current operation: " << *this; 294 return diag; 295 } 296 297 /// Emit a remark about this operation, reporting up to any diagnostic 298 /// handlers that may be listening. 
InFlightDiagnostic Operation::emitRemark(const Twine &message) {
  InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message);
  // When the context requests it, attach the printed operation as a note.
  if (getContext()->shouldPrintOpOnDiagnostic())
    diag.attachNote(getLoc()) << "see current operation: " << *this;
  return diag;
}

//===----------------------------------------------------------------------===//
// Operation Ordering
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the static constexpr members (presumably needed
// because they may be ODR-used under pre-C++17 rules).
constexpr unsigned Operation::kInvalidOrderIdx;
constexpr unsigned Operation::kOrderStride;

/// Given an operation 'other' that is within the same parent block, return
/// whether the current operation is before 'other' in the operation list
/// of the parent block.
/// Note: This function has an average complexity of O(1), but worst case may
/// take O(N) where N is the number of operations within the parent block.
bool Operation::isBeforeInBlock(Operation *other) {
  assert(block && "Operations without parent blocks have no order.");
  assert(other && other->block == block &&
         "Expected other operation to have the same parent block.");
  // If the order of the block is already invalid, directly recompute the
  // parent.
  if (!block->isOpOrderValid()) {
    block->recomputeOpOrder();
  } else {
    // Otherwise, update the order index of either operation if necessary.
    updateOrderIfNecessary();
    other->updateOrderIfNecessary();
  }

  // Both indices are now valid, so they compare in block order.
  return orderIndex < other->orderIndex;
}

/// Update the order index of this operation if necessary, potentially
/// recomputing the order of the parent block.
///
/// A new index is derived from this operation's neighbours: after the previous
/// operation for the last op, below the next operation for the first op, and
/// in the middle of the two otherwise. If the neighbours have no valid index,
/// or there is no free index between them, the whole block is renumbered via
/// recomputeOpOrder() (presumably assigning this operation an index too).
void Operation::updateOrderIfNecessary() {
  assert(block && "expected valid parent");

  // If the order is valid for this operation there is nothing to do.
  if (hasValidOrder())
    return;
  Operation *blockFront = &block->front();
  Operation *blockBack = &block->back();

  // This method is expected to only be invoked on blocks with more than one
  // operation.
  assert(blockFront != blockBack && "expected more than one operation");

  // If the operation is at the end of the block.
  if (this == blockBack) {
    Operation *prevNode = getPrevNode();
    if (!prevNode->hasValidOrder())
      return block->recomputeOpOrder();

    // Add the stride to the previous operation.
    orderIndex = prevNode->orderIndex + kOrderStride;
    return;
  }

  // If this is the first operation try to use the next operation to compute the
  // ordering.
  if (this == blockFront) {
    Operation *nextNode = getNextNode();
    if (!nextNode->hasValidOrder())
      return block->recomputeOpOrder();
    // There is no order to give this operation.
    if (nextNode->orderIndex == 0)
      return block->recomputeOpOrder();

    // If we can't use the stride, just take the middle value left (half of the
    // next operation's index). This is safe because we know there is at least
    // one valid index to assign to.
    if (nextNode->orderIndex <= kOrderStride)
      orderIndex = (nextNode->orderIndex / 2);
    else
      orderIndex = kOrderStride;
    return;
  }

  // Otherwise, this operation is between two others. Place this operation in
  // the middle of the previous and next if possible.
  Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
  if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
    return block->recomputeOpOrder();
  unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;

  // Check to see if there is a valid order between the two.
  if (prevOrder + 1 == nextOrder)
    return block->recomputeOpOrder();
  orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}

//===----------------------------------------------------------------------===//
// ilist_traits for Operation
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the ilist node accessors for Operation; each one
// forwards to the generic NodeAccess helper with this list's options.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(N);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer N)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(N);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(N);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *N)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(N);
}

/// Called by the list when it deletes an operation; operations are released
/// through destroy() rather than `delete`.
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
  op->destroy();
}

/// Recover the Block that owns this operation list. `this` is viewed as the
/// Block's iplist<Operation> member, and the byte offset of that member inside
/// Block (obtained through Block::getSublistAccess) is subtracted from it.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset);
}

/// This is a trait method invoked when an operation is added to a block. We
/// keep the block pointer up to date.
435 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 436 assert(!op->getBlock() && "already in an operation block!"); 437 op->block = getContainingBlock(); 438 439 // Invalidate the order on the operation. 440 op->orderIndex = Operation::kInvalidOrderIdx; 441 } 442 443 /// This is a trait method invoked when an operation is removed from a block. 444 /// We keep the block pointer up to date. 445 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 446 assert(op->block && "not already in an operation block!"); 447 op->block = nullptr; 448 } 449 450 /// This is a trait method invoked when an operation is moved from one block 451 /// to another. We keep the block pointer up to date. 452 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 453 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 454 Block *curParent = getContainingBlock(); 455 456 // Invalidate the ordering of the parent block. 457 curParent->invalidateOpOrder(); 458 459 // If we are transferring operations within the same block, the block 460 // pointer doesn't need to be updated. 461 if (curParent == otherList.getContainingBlock()) 462 return; 463 464 // Update the 'block' member of each operation. 465 for (; first != last; ++first) 466 first->block = curParent; 467 } 468 469 /// Remove this operation (and its descendants) from its Block and delete 470 /// all of them. 471 void Operation::erase() { 472 if (auto *parent = getBlock()) 473 parent->getOperations().erase(this); 474 else 475 destroy(); 476 } 477 478 /// Remove the operation from its parent block, but don't delete it. 479 void Operation::remove() { 480 if (Block *parent = getBlock()) 481 parent->getOperations().remove(this); 482 } 483 484 /// Unlink this operation from its current block and insert it right before 485 /// `existingOp` which may be in the same or another block in the same 486 /// function. 
487 void Operation::moveBefore(Operation *existingOp) { 488 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 489 } 490 491 /// Unlink this operation from its current basic block and insert it right 492 /// before `iterator` in the specified basic block. 493 void Operation::moveBefore(Block *block, 494 llvm::iplist<Operation>::iterator iterator) { 495 block->getOperations().splice(iterator, getBlock()->getOperations(), 496 getIterator()); 497 } 498 499 /// Unlink this operation from its current block and insert it right after 500 /// `existingOp` which may be in the same or another block in the same function. 501 void Operation::moveAfter(Operation *existingOp) { 502 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 503 } 504 505 /// Unlink this operation from its current block and insert it right after 506 /// `iterator` in the specified block. 507 void Operation::moveAfter(Block *block, 508 llvm::iplist<Operation>::iterator iterator) { 509 assert(iterator != block->end() && "cannot move after end of block"); 510 moveBefore(&*std::next(iterator)); 511 } 512 513 /// This drops all operand uses from this operation, which is an essential 514 /// step in breaking cyclic dependences between references when they are to 515 /// be deleted. 516 void Operation::dropAllReferences() { 517 for (auto &op : getOpOperands()) 518 op.drop(); 519 520 for (auto ®ion : getRegions()) 521 region.dropAllReferences(); 522 523 for (auto &dest : getBlockOperands()) 524 dest.drop(); 525 } 526 527 /// This drops all uses of any values defined by this operation or its nested 528 /// regions, wherever they are located. 
529 void Operation::dropAllDefinedValueUses() { 530 dropAllUses(); 531 532 for (auto ®ion : getRegions()) 533 for (auto &block : region) 534 block.dropAllDefinedValueUses(); 535 } 536 537 void Operation::setSuccessor(Block *block, unsigned index) { 538 assert(index < getNumSuccessors()); 539 getBlockOperands()[index].set(block); 540 } 541 542 /// Attempt to fold this operation using the Op's registered foldHook. 543 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 544 SmallVectorImpl<OpFoldResult> &results) { 545 // If we have a registered operation definition matching this one, use it to 546 // try to constant fold the operation. 547 auto *abstractOp = getAbstractOperation(); 548 if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results))) 549 return success(); 550 551 // Otherwise, fall back on the dialect hook to handle it. 552 Dialect *dialect = getDialect(); 553 if (!dialect) 554 return failure(); 555 556 auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>(); 557 if (!interface) 558 return failure(); 559 560 return interface->fold(this, operands, results); 561 } 562 563 /// Emit an error with the op name prefixed, like "'dim' op " which is 564 /// convenient for verifiers. 565 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 566 return emitError() << "'" << getName() << "' op " << message; 567 } 568 569 //===----------------------------------------------------------------------===// 570 // Operation Cloning 571 //===----------------------------------------------------------------------===// 572 573 /// Create a deep copy of this operation but keep the operation regions empty. 574 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 575 /// to contain the results. 576 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 577 SmallVector<Value, 8> operands; 578 SmallVector<Block *, 2> successors; 579 580 // Remap the operands. 
581 operands.reserve(getNumOperands()); 582 for (auto opValue : getOperands()) 583 operands.push_back(mapper.lookupOrDefault(opValue)); 584 585 // Remap the successors. 586 successors.reserve(getNumSuccessors()); 587 for (Block *successor : getSuccessors()) 588 successors.push_back(mapper.lookupOrDefault(successor)); 589 590 // Create the new operation. 591 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 592 successors, getNumRegions()); 593 594 // Remember the mapping of any results. 595 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 596 mapper.map(getResult(i), newOp->getResult(i)); 597 598 return newOp; 599 } 600 601 Operation *Operation::cloneWithoutRegions() { 602 BlockAndValueMapping mapper; 603 return cloneWithoutRegions(mapper); 604 } 605 606 /// Create a deep copy of this operation, remapping any operands that use 607 /// values outside of the operation using the map that is provided (leaving 608 /// them alone if no entry is present). Replaces references to cloned 609 /// sub-operations to the corresponding operation that is copied, and adds 610 /// those mappings to the map. 611 Operation *Operation::clone(BlockAndValueMapping &mapper) { 612 auto *newOp = cloneWithoutRegions(mapper); 613 614 // Clone the regions. 615 for (unsigned i = 0; i != numRegions; ++i) 616 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 617 618 return newOp; 619 } 620 621 Operation *Operation::clone() { 622 BlockAndValueMapping mapper; 623 return clone(mapper); 624 } 625 626 //===----------------------------------------------------------------------===// 627 // OpState trait class. 628 //===----------------------------------------------------------------------===// 629 630 // The fallback for the parser is to reject the custom assembly form. 
631 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 632 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 633 } 634 635 // The fallback for the printer is to print in the generic assembly form. 636 void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); } 637 638 /// Emit an error about fatal conditions with this operation, reporting up to 639 /// any diagnostic handlers that may be listening. 640 InFlightDiagnostic OpState::emitError(const Twine &message) { 641 return getOperation()->emitError(message); 642 } 643 644 /// Emit an error with the op name prefixed, like "'dim' op " which is 645 /// convenient for verifiers. 646 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 647 return getOperation()->emitOpError(message); 648 } 649 650 /// Emit a warning about this operation, reporting up to any diagnostic 651 /// handlers that may be listening. 652 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 653 return getOperation()->emitWarning(message); 654 } 655 656 /// Emit a remark about this operation, reporting up to any diagnostic 657 /// handlers that may be listening. 658 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 659 return getOperation()->emitRemark(message); 660 } 661 662 //===----------------------------------------------------------------------===// 663 // Op Trait implementations 664 //===----------------------------------------------------------------------===// 665 666 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 667 auto *argumentOp = op->getOperand(0).getDefiningOp(); 668 if (argumentOp && op->getName() == argumentOp->getName()) { 669 // Replace the outer operation output with the inner operation. 
670 return op->getOperand(0); 671 } 672 673 return {}; 674 } 675 676 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 677 auto *argumentOp = op->getOperand(0).getDefiningOp(); 678 if (argumentOp && op->getName() == argumentOp->getName()) { 679 // Replace the outer involutions output with inner's input. 680 return argumentOp->getOperand(0); 681 } 682 683 return {}; 684 } 685 686 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 687 if (op->getNumOperands() != 0) 688 return op->emitOpError() << "requires zero operands"; 689 return success(); 690 } 691 692 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 693 if (op->getNumOperands() != 1) 694 return op->emitOpError() << "requires a single operand"; 695 return success(); 696 } 697 698 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 699 unsigned numOperands) { 700 if (op->getNumOperands() != numOperands) { 701 return op->emitOpError() << "expected " << numOperands 702 << " operands, but found " << op->getNumOperands(); 703 } 704 return success(); 705 } 706 707 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 708 unsigned numOperands) { 709 if (op->getNumOperands() < numOperands) 710 return op->emitOpError() 711 << "expected " << numOperands << " or more operands"; 712 return success(); 713 } 714 715 /// If this is a vector type, or a tensor type, return the scalar element type 716 /// that it is built around, otherwise return the type unmodified. 717 static Type getTensorOrVectorElementType(Type type) { 718 if (auto vec = type.dyn_cast<VectorType>()) 719 return vec.getElementType(); 720 721 // Look through tensor<vector<...>> to find the underlying element type. 722 if (auto tensor = type.dyn_cast<TensorType>()) 723 return getTensorOrVectorElementType(tensor.getElementType()); 724 return type; 725 } 726 727 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 728 // FIXME: Add back check for no side effects on operation. 
729 // Currently adding it would cause the shared library build 730 // to fail since there would be a dependency of IR on SideEffectInterfaces 731 // which is cyclical. 732 return success(); 733 } 734 735 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 736 // FIXME: Add back check for no side effects on operation. 737 // Currently adding it would cause the shared library build 738 // to fail since there would be a dependency of IR on SideEffectInterfaces 739 // which is cyclical. 740 return success(); 741 } 742 743 LogicalResult 744 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 745 for (auto opType : op->getOperandTypes()) { 746 auto type = getTensorOrVectorElementType(opType); 747 if (!type.isSignlessIntOrIndex()) 748 return op->emitOpError() << "requires an integer or index type"; 749 } 750 return success(); 751 } 752 753 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 754 for (auto opType : op->getOperandTypes()) { 755 auto type = getTensorOrVectorElementType(opType); 756 if (!type.isa<FloatType>()) 757 return op->emitOpError("requires a float type"); 758 } 759 return success(); 760 } 761 762 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 763 // Zero or one operand always have the "same" type. 
764 unsigned nOperands = op->getNumOperands(); 765 if (nOperands < 2) 766 return success(); 767 768 auto type = op->getOperand(0).getType(); 769 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 770 if (opType != type) 771 return op->emitOpError() << "requires all operands to have the same type"; 772 return success(); 773 } 774 775 LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) { 776 if (op->getNumRegions() != 0) 777 return op->emitOpError() << "requires zero regions"; 778 return success(); 779 } 780 781 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 782 if (op->getNumRegions() != 1) 783 return op->emitOpError() << "requires one region"; 784 return success(); 785 } 786 787 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 788 unsigned numRegions) { 789 if (op->getNumRegions() != numRegions) 790 return op->emitOpError() << "expected " << numRegions << " regions"; 791 return success(); 792 } 793 794 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 795 unsigned numRegions) { 796 if (op->getNumRegions() < numRegions) 797 return op->emitOpError() << "expected " << numRegions << " or more regions"; 798 return success(); 799 } 800 801 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { 802 if (op->getNumResults() != 0) 803 return op->emitOpError() << "requires zero results"; 804 return success(); 805 } 806 807 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 808 if (op->getNumResults() != 1) 809 return op->emitOpError() << "requires one result"; 810 return success(); 811 } 812 813 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 814 unsigned numOperands) { 815 if (op->getNumResults() != numOperands) 816 return op->emitOpError() << "expected " << numOperands << " results"; 817 return success(); 818 } 819 820 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 821 unsigned numOperands) { 822 if (op->getNumResults() < numOperands) 823 return op->emitOpError() 
824 << "expected " << numOperands << " or more results"; 825 return success(); 826 } 827 828 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 829 if (failed(verifyAtLeastNOperands(op, 1))) 830 return failure(); 831 832 if (failed(verifyCompatibleShapes(op->getOperandTypes()))) 833 return op->emitOpError() << "requires the same shape for all operands"; 834 835 return success(); 836 } 837 838 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 839 if (failed(verifyAtLeastNOperands(op, 1)) || 840 failed(verifyAtLeastNResults(op, 1))) 841 return failure(); 842 843 SmallVector<Type, 8> types(op->getOperandTypes()); 844 types.append(llvm::to_vector<4>(op->getResultTypes())); 845 846 if (failed(verifyCompatibleShapes(types))) 847 return op->emitOpError() 848 << "requires the same shape for all operands and results"; 849 850 return success(); 851 } 852 853 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 854 if (failed(verifyAtLeastNOperands(op, 1))) 855 return failure(); 856 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 857 858 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 859 if (getElementTypeOrSelf(operand) != elementType) 860 return op->emitOpError("requires the same element type for all operands"); 861 } 862 863 return success(); 864 } 865 866 LogicalResult 867 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 868 if (failed(verifyAtLeastNOperands(op, 1)) || 869 failed(verifyAtLeastNResults(op, 1))) 870 return failure(); 871 872 auto elementType = getElementTypeOrSelf(op->getResult(0)); 873 874 // Verify result element type matches first result's element type. 
875 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 876 if (getElementTypeOrSelf(result) != elementType) 877 return op->emitOpError( 878 "requires the same element type for all operands and results"); 879 } 880 881 // Verify operand's element type matches first result's element type. 882 for (auto operand : op->getOperands()) { 883 if (getElementTypeOrSelf(operand) != elementType) 884 return op->emitOpError( 885 "requires the same element type for all operands and results"); 886 } 887 888 return success(); 889 } 890 891 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 892 if (failed(verifyAtLeastNOperands(op, 1)) || 893 failed(verifyAtLeastNResults(op, 1))) 894 return failure(); 895 896 auto type = op->getResult(0).getType(); 897 auto elementType = getElementTypeOrSelf(type); 898 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 899 if (getElementTypeOrSelf(resultType) != elementType || 900 failed(verifyCompatibleShape(resultType, type))) 901 return op->emitOpError() 902 << "requires the same type for all operands and results"; 903 } 904 for (auto opType : op->getOperandTypes()) { 905 if (getElementTypeOrSelf(opType) != elementType || 906 failed(verifyCompatibleShape(opType, type))) 907 return op->emitOpError() 908 << "requires the same type for all operands and results"; 909 } 910 return success(); 911 } 912 913 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 914 Block *block = op->getBlock(); 915 // Verify that the operation is at the end of the respective parent block. 916 if (!block || &block->back() != op) 917 return op->emitOpError("must be the last operation in the parent block"); 918 return success(); 919 } 920 921 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 922 auto *parent = op->getParentRegion(); 923 924 // Verify that the operands lines up with the BB arguments in the successor. 
925 for (Block *succ : op->getSuccessors()) 926 if (succ->getParent() != parent) 927 return op->emitError("reference to block defined in another region"); 928 return success(); 929 } 930 931 LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { 932 if (op->getNumSuccessors() != 0) { 933 return op->emitOpError("requires 0 successors but found ") 934 << op->getNumSuccessors(); 935 } 936 return success(); 937 } 938 939 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 940 if (op->getNumSuccessors() != 1) { 941 return op->emitOpError("requires 1 successor but found ") 942 << op->getNumSuccessors(); 943 } 944 return verifyTerminatorSuccessors(op); 945 } 946 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 947 unsigned numSuccessors) { 948 if (op->getNumSuccessors() != numSuccessors) { 949 return op->emitOpError("requires ") 950 << numSuccessors << " successors but found " 951 << op->getNumSuccessors(); 952 } 953 return verifyTerminatorSuccessors(op); 954 } 955 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 956 unsigned numSuccessors) { 957 if (op->getNumSuccessors() < numSuccessors) { 958 return op->emitOpError("requires at least ") 959 << numSuccessors << " successors but found " 960 << op->getNumSuccessors(); 961 } 962 return verifyTerminatorSuccessors(op); 963 } 964 965 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 966 for (auto resultType : op->getResultTypes()) { 967 auto elementType = getTensorOrVectorElementType(resultType); 968 bool isBoolType = elementType.isInteger(1); 969 if (!isBoolType) 970 return op->emitOpError() << "requires a bool result type"; 971 } 972 973 return success(); 974 } 975 976 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 977 for (auto resultType : op->getResultTypes()) 978 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 979 return op->emitOpError() << "requires a floating point type"; 980 981 return success(); 
982 } 983 984 LogicalResult 985 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 986 for (auto resultType : op->getResultTypes()) 987 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 988 return op->emitOpError() << "requires an integer or index type"; 989 return success(); 990 } 991 992 static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName, 993 bool isOperand) { 994 auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); 995 if (!sizeAttr) 996 return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; 997 998 auto sizeAttrType = sizeAttr.getType().dyn_cast<VectorType>(); 999 if (!sizeAttrType || sizeAttrType.getRank() != 1 || 1000 !sizeAttrType.getElementType().isInteger(32)) 1001 return op->emitOpError("requires 1D vector of i32 attribute '") 1002 << attrName << "'"; 1003 1004 if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) { 1005 return !element.isNonNegative(); 1006 })) 1007 return op->emitOpError("'") 1008 << attrName << "' attribute cannot have negative elements"; 1009 1010 size_t totalCount = std::accumulate( 1011 sizeAttr.begin(), sizeAttr.end(), 0, 1012 [](unsigned all, APInt one) { return all + one.getZExtValue(); }); 1013 1014 if (isOperand && totalCount != op->getNumOperands()) 1015 return op->emitOpError("operand count (") 1016 << op->getNumOperands() << ") does not match with the total size (" 1017 << totalCount << ") specified in attribute '" << attrName << "'"; 1018 else if (!isOperand && totalCount != op->getNumResults()) 1019 return op->emitOpError("result count (") 1020 << op->getNumResults() << ") does not match with the total size (" 1021 << totalCount << ") specified in attribute '" << attrName << "'"; 1022 return success(); 1023 } 1024 1025 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1026 StringRef attrName) { 1027 return verifyValueSizeAttr(op, attrName, /*isOperand=*/true); 1028 } 1029 1030 LogicalResult 
OpTrait::impl::verifyResultSizeAttr(Operation *op, 1031 StringRef attrName) { 1032 return verifyValueSizeAttr(op, attrName, /*isOperand=*/false); 1033 } 1034 1035 LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1036 for (Region ®ion : op->getRegions()) { 1037 if (region.empty()) 1038 continue; 1039 1040 if (region.getNumArguments() != 0) { 1041 if (op->getNumRegions() > 1) 1042 return op->emitOpError("region #") 1043 << region.getRegionNumber() << " should have no arguments"; 1044 else 1045 return op->emitOpError("region should have no arguments"); 1046 } 1047 } 1048 return success(); 1049 } 1050 1051 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1052 auto isMappableType = [](Type type) { 1053 return type.isa<VectorType, TensorType>(); 1054 }; 1055 auto resultMappableTypes = llvm::to_vector<1>( 1056 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1057 auto operandMappableTypes = llvm::to_vector<2>( 1058 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1059 1060 // If the op only has scalar operand/result types, then we have nothing to 1061 // check. 
1062 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1063 return success(); 1064 1065 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1066 return op->emitOpError("if a result is non-scalar, then at least one " 1067 "operand must be non-scalar"); 1068 1069 assert(!operandMappableTypes.empty()); 1070 1071 if (resultMappableTypes.empty()) 1072 return op->emitOpError("if an operand is non-scalar, then there must be at " 1073 "least one non-scalar result"); 1074 1075 if (resultMappableTypes.size() != op->getNumResults()) 1076 return op->emitOpError( 1077 "if an operand is non-scalar, then all results must be non-scalar"); 1078 1079 SmallVector<Type, 4> types = llvm::to_vector<2>( 1080 llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); 1081 TypeID expectedBaseTy = types.front().getTypeID(); 1082 if (!llvm::all_of(types, 1083 [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || 1084 failed(verifyCompatibleShapes(types))) { 1085 return op->emitOpError() << "all non-scalar operands/results must have the " 1086 "same shape and base type"; 1087 } 1088 1089 return success(); 1090 } 1091 1092 /// Check for any values used by operations regions attached to the 1093 /// specified "IsIsolatedFromAbove" operation defined outside of it. 1094 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { 1095 assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1096 "Intended to check IsolatedFromAbove ops"); 1097 1098 // List of regions to analyze. Each region is processed independently, with 1099 // respect to the common `limit` region, so we can look at them in any order. 1100 // Therefore, use a simple vector and push/pop back the current region. 1101 SmallVector<Region *, 8> pendingRegions; 1102 for (auto ®ion : isolatedOp->getRegions()) { 1103 pendingRegions.push_back(®ion); 1104 1105 // Traverse all operations in the region. 
1106 while (!pendingRegions.empty()) { 1107 for (Operation &op : pendingRegions.pop_back_val()->getOps()) { 1108 for (Value operand : op.getOperands()) { 1109 // operand should be non-null here if the IR is well-formed. But 1110 // we don't assert here as this function is called from the verifier 1111 // and so could be called on invalid IR. 1112 if (!operand) 1113 return op.emitOpError("operation's operand is null"); 1114 1115 // Check that any value that is used by an operation is defined in the 1116 // same region as either an operation result. 1117 auto *operandRegion = operand.getParentRegion(); 1118 if (!region.isAncestor(operandRegion)) { 1119 return op.emitOpError("using value defined outside the region") 1120 .attachNote(isolatedOp->getLoc()) 1121 << "required by region isolation constraints"; 1122 } 1123 } 1124 1125 // Schedule any regions in the operation for further checking. Don't 1126 // recurse into other IsolatedFromAbove ops, because they will check 1127 // themselves. 1128 if (op.getNumRegions() && 1129 !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1130 for (Region &subRegion : op.getRegions()) 1131 pendingRegions.push_back(&subRegion); 1132 } 1133 } 1134 } 1135 } 1136 1137 return success(); 1138 } 1139 1140 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1141 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1142 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1143 } 1144 1145 //===----------------------------------------------------------------------===// 1146 // BinaryOp implementation 1147 //===----------------------------------------------------------------------===// 1148 1149 // These functions are out-of-line implementations of the methods in BinaryOp, 1150 // which avoids them being template instantiated/duplicated. 
1151 1152 void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, 1153 Value rhs) { 1154 assert(lhs.getType() == rhs.getType()); 1155 result.addOperands({lhs, rhs}); 1156 result.types.push_back(lhs.getType()); 1157 } 1158 1159 ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser, 1160 OperationState &result) { 1161 SmallVector<OpAsmParser::OperandType, 2> ops; 1162 Type type; 1163 return failure(parser.parseOperandList(ops) || 1164 parser.parseOptionalAttrDict(result.attributes) || 1165 parser.parseColonType(type) || 1166 parser.resolveOperands(ops, type, result.operands) || 1167 parser.addTypeToList(type, result.types)); 1168 } 1169 1170 void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { 1171 assert(op->getNumResults() == 1 && "op should have one result"); 1172 1173 // If not all the operand and result types are the same, just use the 1174 // generic assembly form to avoid omitting information in printing. 1175 auto resultType = op->getResult(0).getType(); 1176 if (llvm::any_of(op->getOperandTypes(), 1177 [&](Type type) { return type != resultType; })) { 1178 p.printGenericOp(op); 1179 return; 1180 } 1181 1182 p << op->getName() << ' '; 1183 p.printOperands(op->getOperands()); 1184 p.printOptionalAttrDict(op->getAttrs()); 1185 // Now we can output only one type for all operands and the result. 1186 p << " : " << resultType; 1187 } 1188 1189 //===----------------------------------------------------------------------===// 1190 // CastOp implementation 1191 //===----------------------------------------------------------------------===// 1192 1193 /// Attempt to fold the given cast operation. 
1194 LogicalResult 1195 impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, 1196 SmallVectorImpl<OpFoldResult> &foldResults) { 1197 OperandRange operands = op->getOperands(); 1198 if (operands.empty()) 1199 return failure(); 1200 ResultRange results = op->getResults(); 1201 1202 // Check for the case where the input and output types match 1-1. 1203 if (operands.getTypes() == results.getTypes()) { 1204 foldResults.append(operands.begin(), operands.end()); 1205 return success(); 1206 } 1207 1208 return failure(); 1209 } 1210 1211 /// Attempt to verify the given cast operation. 1212 LogicalResult impl::verifyCastInterfaceOp( 1213 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { 1214 auto resultTypes = op->getResultTypes(); 1215 if (llvm::empty(resultTypes)) 1216 return op->emitOpError() 1217 << "expected at least one result for cast operation"; 1218 1219 auto operandTypes = op->getOperandTypes(); 1220 if (!areCastCompatible(operandTypes, resultTypes)) { 1221 InFlightDiagnostic diag = op->emitOpError("operand type"); 1222 if (llvm::empty(operandTypes)) 1223 diag << "s []"; 1224 else if (llvm::size(operandTypes) == 1) 1225 diag << " " << *operandTypes.begin(); 1226 else 1227 diag << "s " << operandTypes; 1228 return diag << " and result type" << (resultTypes.size() == 1 ? 
" " : "s ") 1229 << resultTypes << " are cast incompatible"; 1230 } 1231 1232 return success(); 1233 } 1234 1235 void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source, 1236 Type destType) { 1237 result.addOperands(source); 1238 result.addTypes(destType); 1239 } 1240 1241 ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) { 1242 OpAsmParser::OperandType srcInfo; 1243 Type srcType, dstType; 1244 return failure(parser.parseOperand(srcInfo) || 1245 parser.parseOptionalAttrDict(result.attributes) || 1246 parser.parseColonType(srcType) || 1247 parser.resolveOperand(srcInfo, srcType, result.operands) || 1248 parser.parseKeywordType("to", dstType) || 1249 parser.addTypeToList(dstType, result.types)); 1250 } 1251 1252 void impl::printCastOp(Operation *op, OpAsmPrinter &p) { 1253 p << op->getName() << ' ' << op->getOperand(0); 1254 p.printOptionalAttrDict(op->getAttrs()); 1255 p << " : " << op->getOperand(0).getType() << " to " 1256 << op->getResult(0).getType(); 1257 } 1258 1259 Value impl::foldCastOp(Operation *op) { 1260 // Identity cast 1261 if (op->getOperand(0).getType() == op->getResult(0).getType()) 1262 return op->getOperand(0); 1263 return nullptr; 1264 } 1265 1266 LogicalResult 1267 impl::verifyCastOp(Operation *op, 1268 function_ref<bool(Type, Type)> areCastCompatible) { 1269 auto opType = op->getOperand(0).getType(); 1270 auto resType = op->getResult(0).getType(); 1271 if (!areCastCompatible(opType, resType)) 1272 return op->emitError("operand type ") 1273 << opType << " and result type " << resType 1274 << " are cast incompatible"; 1275 1276 return success(); 1277 } 1278 1279 //===----------------------------------------------------------------------===// 1280 // Misc. 
utils 1281 //===----------------------------------------------------------------------===// 1282 1283 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1284 /// region's only block if it does not have a terminator already. If the region 1285 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1286 /// terminator operation to insert. 1287 void impl::ensureRegionTerminator( 1288 Region ®ion, OpBuilder &builder, Location loc, 1289 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1290 OpBuilder::InsertionGuard guard(builder); 1291 if (region.empty()) 1292 builder.createBlock(®ion); 1293 1294 Block &block = region.back(); 1295 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1296 return; 1297 1298 builder.setInsertionPointToEnd(&block); 1299 builder.insert(buildTerminatorOp(builder, loc)); 1300 } 1301 1302 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1303 /// function. 1304 void impl::ensureRegionTerminator( 1305 Region ®ion, Builder &builder, Location loc, 1306 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1307 OpBuilder opBuilder(builder.getContext()); 1308 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1309 } 1310 1311 //===----------------------------------------------------------------------===// 1312 // UseIterator 1313 //===----------------------------------------------------------------------===// 1314 1315 Operation::UseIterator::UseIterator(Operation *op, bool end) 1316 : op(op), res(end ? op->result_end() : op->result_begin()) { 1317 // Only initialize current use if there are results/can be uses. 1318 if (op->getNumResults()) 1319 skipOverResultsWithNoUsers(); 1320 } 1321 1322 Operation::UseIterator &Operation::UseIterator::operator++() { 1323 // We increment over uses, if we reach the last use then move to next 1324 // result. 
1325 if (use != (*res).use_end()) 1326 ++use; 1327 if (use == (*res).use_end()) { 1328 ++res; 1329 skipOverResultsWithNoUsers(); 1330 } 1331 return *this; 1332 } 1333 1334 void Operation::UseIterator::skipOverResultsWithNoUsers() { 1335 while (res != op->result_end() && (*res).use_empty()) 1336 ++res; 1337 1338 // If we are at the last result, then set use to first use of 1339 // first result (sentinel value used for end). 1340 if (res == op->result_end()) 1341 use = {}; 1342 else 1343 use = (*res).use_begin(); 1344 } 1345