1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Interfaces/FoldInterfaces.h" 17 #include "llvm/ADT/StringExtras.h" 18 #include <numeric> 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // Operation 24 //===----------------------------------------------------------------------===// 25 26 /// Create a new Operation with the specific fields. 27 Operation *Operation::create(Location location, OperationName name, 28 TypeRange resultTypes, ValueRange operands, 29 ArrayRef<NamedAttribute> attributes, 30 BlockRange successors, unsigned numRegions) { 31 return create(location, name, resultTypes, operands, 32 DictionaryAttr::get(location.getContext(), attributes), 33 successors, numRegions); 34 } 35 36 /// Create a new Operation from operation state. 37 Operation *Operation::create(const OperationState &state) { 38 return create(state.location, state.name, state.types, state.operands, 39 state.attributes.getDictionary(state.getContext()), 40 state.successors, state.regions); 41 } 42 43 /// Create a new Operation with the specific fields. 
44 Operation *Operation::create(Location location, OperationName name, 45 TypeRange resultTypes, ValueRange operands, 46 DictionaryAttr attributes, BlockRange successors, 47 RegionRange regions) { 48 unsigned numRegions = regions.size(); 49 Operation *op = create(location, name, resultTypes, operands, attributes, 50 successors, numRegions); 51 for (unsigned i = 0; i < numRegions; ++i) 52 if (regions[i]) 53 op->getRegion(i).takeBody(*regions[i]); 54 return op; 55 } 56 57 /// Overload of create that takes an existing DictionaryAttr to avoid 58 /// unnecessarily uniquing a list of attributes. 59 Operation *Operation::create(Location location, OperationName name, 60 TypeRange resultTypes, ValueRange operands, 61 DictionaryAttr attributes, BlockRange successors, 62 unsigned numRegions) { 63 assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && 64 "unexpected null result type"); 65 66 // We only need to allocate additional memory for a subset of results. 67 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 68 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 69 unsigned numSuccessors = successors.size(); 70 unsigned numOperands = operands.size(); 71 unsigned numResults = resultTypes.size(); 72 73 // If the operation is known to have no operands, don't allocate an operand 74 // storage. 75 bool needsOperandStorage = 76 operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; 77 78 // Compute the byte size for the operation and the operand storage. This takes 79 // into account the size of the operation, its trailing objects, and its 80 // prefixed objects. 81 size_t byteSize = 82 totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>( 83 needsOperandStorage ? 
1 : 0, numSuccessors, numRegions, numOperands); 84 size_t prefixByteSize = llvm::alignTo( 85 Operation::prefixAllocSize(numTrailingResults, numInlineResults), 86 alignof(Operation)); 87 char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); 88 void *rawMem = mallocMem + prefixByteSize; 89 90 // Create the new Operation. 91 Operation *op = 92 ::new (rawMem) Operation(location, name, numResults, numSuccessors, 93 numRegions, attributes, needsOperandStorage); 94 95 assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && 96 "unexpected successors in a non-terminator operation"); 97 98 // Initialize the results. 99 auto resultTypeIt = resultTypes.begin(); 100 for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) 101 new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); 102 for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { 103 new (op->getOutOfLineOpResult(i)) 104 detail::OutOfLineOpResult(*resultTypeIt, i); 105 } 106 107 // Initialize the regions. 108 for (unsigned i = 0; i != numRegions; ++i) 109 new (&op->getRegion(i)) Region(op); 110 111 // Initialize the operands. 112 if (needsOperandStorage) { 113 new (&op->getOperandStorage()) detail::OperandStorage( 114 op, op->getTrailingObjects<OpOperand>(), operands); 115 } 116 117 // Initialize the successors. 
118 auto blockOperands = op->getBlockOperands(); 119 for (unsigned i = 0; i != numSuccessors; ++i) 120 new (&blockOperands[i]) BlockOperand(op, successors[i]); 121 122 return op; 123 } 124 125 Operation::Operation(Location location, OperationName name, unsigned numResults, 126 unsigned numSuccessors, unsigned numRegions, 127 DictionaryAttr attributes, bool hasOperandStorage) 128 : location(location), numResults(numResults), numSuccs(numSuccessors), 129 numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), 130 attrs(attributes) { 131 assert(attributes && "unexpected null attribute dictionary"); 132 #ifndef NDEBUG 133 if (!getDialect() && !getContext()->allowsUnregisteredDialects()) 134 llvm::report_fatal_error( 135 name.getStringRef() + 136 " created with unregistered dialect. If this is intended, please call " 137 "allowUnregisteredDialects() on the MLIRContext, or use " 138 "-allow-unregistered-dialect with the MLIR tool used."); 139 #endif 140 } 141 142 // Operations are deleted through the destroy() member because they are 143 // allocated via malloc. 144 Operation::~Operation() { 145 assert(block == nullptr && "operation destroyed but still in a block"); 146 #ifndef NDEBUG 147 if (!use_empty()) { 148 { 149 InFlightDiagnostic diag = 150 emitOpError("operation destroyed but still has uses"); 151 for (Operation *user : getUsers()) 152 diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; 153 } 154 llvm::report_fatal_error("operation destroyed but still has uses"); 155 } 156 #endif 157 // Explicitly run the destructors for the operands. 158 if (hasOperandStorage) 159 getOperandStorage().~OperandStorage(); 160 161 // Explicitly run the destructors for the successors. 162 for (auto &successor : getBlockOperands()) 163 successor.~BlockOperand(); 164 165 // Explicitly destroy the regions. 166 for (auto ®ion : getRegions()) 167 region.~Region(); 168 } 169 170 /// Destroy this operation or one of its subclasses. 
171 void Operation::destroy() { 172 // Operations may have additional prefixed allocation, which needs to be 173 // accounted for here when computing the address to free. 174 char *rawMem = reinterpret_cast<char *>(this) - 175 llvm::alignTo(prefixAllocSize(), alignof(Operation)); 176 this->~Operation(); 177 free(rawMem); 178 } 179 180 /// Return true if this operation is a proper ancestor of the `other` 181 /// operation. 182 bool Operation::isProperAncestor(Operation *other) { 183 while ((other = other->getParentOp())) 184 if (this == other) 185 return true; 186 return false; 187 } 188 189 /// Replace any uses of 'from' with 'to' within this operation. 190 void Operation::replaceUsesOfWith(Value from, Value to) { 191 if (from == to) 192 return; 193 for (auto &operand : getOpOperands()) 194 if (operand.get() == from) 195 operand.set(to); 196 } 197 198 /// Replace the current operands of this operation with the ones provided in 199 /// 'operands'. 200 void Operation::setOperands(ValueRange operands) { 201 if (LLVM_LIKELY(hasOperandStorage)) 202 return getOperandStorage().setOperands(this, operands); 203 assert(operands.empty() && "setting operands without an operand storage"); 204 } 205 206 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 207 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 208 /// than the range pointed to by 'start'+'length'. 209 void Operation::setOperands(unsigned start, unsigned length, 210 ValueRange operands) { 211 assert((start + length) <= getNumOperands() && 212 "invalid operand range specified"); 213 if (LLVM_LIKELY(hasOperandStorage)) 214 return getOperandStorage().setOperands(this, start, length, operands); 215 assert(operands.empty() && "setting operands without an operand storage"); 216 } 217 218 /// Insert the given operands into the operand list at the given 'index'. 
219 void Operation::insertOperands(unsigned index, ValueRange operands) { 220 if (LLVM_LIKELY(hasOperandStorage)) 221 return setOperands(index, /*length=*/0, operands); 222 assert(operands.empty() && "inserting operands without an operand storage"); 223 } 224 225 //===----------------------------------------------------------------------===// 226 // Diagnostics 227 //===----------------------------------------------------------------------===// 228 229 /// Emit an error about fatal conditions with this operation, reporting up to 230 /// any diagnostic handlers that may be listening. 231 InFlightDiagnostic Operation::emitError(const Twine &message) { 232 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 233 if (getContext()->shouldPrintOpOnDiagnostic()) { 234 diag.attachNote(getLoc()) 235 .append("see current operation: ") 236 .appendOp(*this, OpPrintingFlags().printGenericOpForm()); 237 } 238 return diag; 239 } 240 241 /// Emit a warning about this operation, reporting up to any diagnostic 242 /// handlers that may be listening. 243 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 244 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 245 if (getContext()->shouldPrintOpOnDiagnostic()) 246 diag.attachNote(getLoc()) << "see current operation: " << *this; 247 return diag; 248 } 249 250 /// Emit a remark about this operation, reporting up to any diagnostic 251 /// handlers that may be listening. 
/// If the context requests it, a note printing the operation is attached.
InFlightDiagnostic Operation::emitRemark(const Twine &message) {
  InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message);
  if (getContext()->shouldPrintOpOnDiagnostic())
    diag.attachNote(getLoc()) << "see current operation: " << *this;
  return diag;
}

//===----------------------------------------------------------------------===//
// Operation Ordering
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the static constexpr members declared in the
// class; these are needed if the members are ODR-used when compiling as C++14.
constexpr unsigned Operation::kInvalidOrderIdx;
constexpr unsigned Operation::kOrderStride;

/// Given an operation 'other' that is within the same parent block, return
/// whether the current operation is before 'other' in the operation list
/// of the parent block.
/// Note: This function has an average complexity of O(1), but worst case may
/// take O(N) where N is the number of operations within the parent block.
bool Operation::isBeforeInBlock(Operation *other) {
  assert(block && "Operations without parent blocks have no order.");
  assert(other && other->block == block &&
         "Expected other operation to have the same parent block.");
  // If the order of the block is already invalid, directly recompute the
  // order of the whole parent block.
  if (!block->isOpOrderValid()) {
    block->recomputeOpOrder();
  } else {
    // Update the order of either operation if necessary.
    updateOrderIfNecessary();
    other->updateOrderIfNecessary();
  }

  return orderIndex < other->orderIndex;
}

/// Update the order index of this operation if necessary, potentially
/// recomputing the order of the parent block.
/// An index is derived locally from the neighbouring operations whenever
/// possible; if a neighbour has no valid index, or there is no free index
/// between the neighbours, the whole block is renumbered instead.
void Operation::updateOrderIfNecessary() {
  assert(block && "expected valid parent");

  // If the order is valid for this operation there is nothing to do.
  if (hasValidOrder())
    return;
  Operation *blockFront = &block->front();
  Operation *blockBack = &block->back();

  // This method is expected to only be invoked on blocks with more than one
  // operation.
  assert(blockFront != blockBack && "expected more than one operation");

  // If the operation is at the end of the block.
  if (this == blockBack) {
    Operation *prevNode = getPrevNode();
    if (!prevNode->hasValidOrder())
      return block->recomputeOpOrder();

    // Add the stride to the previous operation.
    // NOTE(review): this sum is not checked for unsigned wraparound; presumably
    // unreachable in practice — confirm.
    orderIndex = prevNode->orderIndex + kOrderStride;
    return;
  }

  // If this is the first operation try to use the next operation to compute the
  // ordering.
  if (this == blockFront) {
    Operation *nextNode = getNextNode();
    if (!nextNode->hasValidOrder())
      return block->recomputeOpOrder();
    // There is no order to give this operation.
    if (nextNode->orderIndex == 0)
      return block->recomputeOpOrder();

    // If we can't use the stride, take half of the next operation's index.
    // This is safe because the next index is at least 1, so there is at least
    // one valid (smaller) index to assign.
    if (nextNode->orderIndex <= kOrderStride)
      orderIndex = (nextNode->orderIndex / 2);
    else
      orderIndex = kOrderStride;
    return;
  }

  // Otherwise, this operation is between two others. Place this operation in
  // the middle of the previous and next if possible.
  Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
  if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
    return block->recomputeOpOrder();
  unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;

  // Check to see if there is a valid order between the two.
  if (prevOrder + 1 == nextOrder)
    return block->recomputeOpOrder();
  orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}

//===----------------------------------------------------------------------===//
// ilist_traits for Operation
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the node <-> value accessors used by the intrusive
// list of Operations; each one forwards to llvm::ilist_detail::NodeAccess.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(N);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer N)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(N);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(N);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *N)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(N);
}

/// Nodes removed by the list are released through Operation::destroy, which
/// frees the malloc'd storage.
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
  op->destroy();
}

/// Recover the Block that owns this operation list. The byte offset of the
/// list inside Block is obtained from the member pointer returned by
/// Block::getSublistAccess, and subtracted from the address of this list.
/// NOTE(review): the offset is computed by applying ->* to a null Block
/// pointer (an offsetof-style idiom); that is not sanctioned by the C++
/// standard although compilers accept it in practice.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset);
}

/// This is a trait method invoked when an operation is added to a block. We
/// keep the block pointer up to date.
388 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 389 assert(!op->getBlock() && "already in an operation block!"); 390 op->block = getContainingBlock(); 391 392 // Invalidate the order on the operation. 393 op->orderIndex = Operation::kInvalidOrderIdx; 394 } 395 396 /// This is a trait method invoked when an operation is removed from a block. 397 /// We keep the block pointer up to date. 398 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 399 assert(op->block && "not already in an operation block!"); 400 op->block = nullptr; 401 } 402 403 /// This is a trait method invoked when an operation is moved from one block 404 /// to another. We keep the block pointer up to date. 405 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 406 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 407 Block *curParent = getContainingBlock(); 408 409 // Invalidate the ordering of the parent block. 410 curParent->invalidateOpOrder(); 411 412 // If we are transferring operations within the same block, the block 413 // pointer doesn't need to be updated. 414 if (curParent == otherList.getContainingBlock()) 415 return; 416 417 // Update the 'block' member of each operation. 418 for (; first != last; ++first) 419 first->block = curParent; 420 } 421 422 /// Remove this operation (and its descendants) from its Block and delete 423 /// all of them. 424 void Operation::erase() { 425 if (auto *parent = getBlock()) 426 parent->getOperations().erase(this); 427 else 428 destroy(); 429 } 430 431 /// Remove the operation from its parent block, but don't delete it. 432 void Operation::remove() { 433 if (Block *parent = getBlock()) 434 parent->getOperations().remove(this); 435 } 436 437 /// Unlink this operation from its current block and insert it right before 438 /// `existingOp` which may be in the same or another block in the same 439 /// function. 
440 void Operation::moveBefore(Operation *existingOp) { 441 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 442 } 443 444 /// Unlink this operation from its current basic block and insert it right 445 /// before `iterator` in the specified basic block. 446 void Operation::moveBefore(Block *block, 447 llvm::iplist<Operation>::iterator iterator) { 448 block->getOperations().splice(iterator, getBlock()->getOperations(), 449 getIterator()); 450 } 451 452 /// Unlink this operation from its current block and insert it right after 453 /// `existingOp` which may be in the same or another block in the same function. 454 void Operation::moveAfter(Operation *existingOp) { 455 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 456 } 457 458 /// Unlink this operation from its current block and insert it right after 459 /// `iterator` in the specified block. 460 void Operation::moveAfter(Block *block, 461 llvm::iplist<Operation>::iterator iterator) { 462 assert(iterator != block->end() && "cannot move after end of block"); 463 moveBefore(block, std::next(iterator)); 464 } 465 466 /// This drops all operand uses from this operation, which is an essential 467 /// step in breaking cyclic dependences between references when they are to 468 /// be deleted. 469 void Operation::dropAllReferences() { 470 for (auto &op : getOpOperands()) 471 op.drop(); 472 473 for (auto ®ion : getRegions()) 474 region.dropAllReferences(); 475 476 for (auto &dest : getBlockOperands()) 477 dest.drop(); 478 } 479 480 /// This drops all uses of any values defined by this operation or its nested 481 /// regions, wherever they are located. 
482 void Operation::dropAllDefinedValueUses() { 483 dropAllUses(); 484 485 for (auto ®ion : getRegions()) 486 for (auto &block : region) 487 block.dropAllDefinedValueUses(); 488 } 489 490 void Operation::setSuccessor(Block *block, unsigned index) { 491 assert(index < getNumSuccessors()); 492 getBlockOperands()[index].set(block); 493 } 494 495 /// Attempt to fold this operation using the Op's registered foldHook. 496 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 497 SmallVectorImpl<OpFoldResult> &results) { 498 // If we have a registered operation definition matching this one, use it to 499 // try to constant fold the operation. 500 Optional<RegisteredOperationName> info = getRegisteredInfo(); 501 if (info && succeeded(info->foldHook(this, operands, results))) 502 return success(); 503 504 // Otherwise, fall back on the dialect hook to handle it. 505 Dialect *dialect = getDialect(); 506 if (!dialect) 507 return failure(); 508 509 auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>(); 510 if (!interface) 511 return failure(); 512 513 return interface->fold(this, operands, results); 514 } 515 516 /// Emit an error with the op name prefixed, like "'dim' op " which is 517 /// convenient for verifiers. 518 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 519 return emitError() << "'" << getName() << "' op " << message; 520 } 521 522 //===----------------------------------------------------------------------===// 523 // Operation Cloning 524 //===----------------------------------------------------------------------===// 525 526 /// Create a deep copy of this operation but keep the operation regions empty. 527 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 528 /// to contain the results. 529 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 530 SmallVector<Value, 8> operands; 531 SmallVector<Block *, 2> successors; 532 533 // Remap the operands. 
534 operands.reserve(getNumOperands()); 535 for (auto opValue : getOperands()) 536 operands.push_back(mapper.lookupOrDefault(opValue)); 537 538 // Remap the successors. 539 successors.reserve(getNumSuccessors()); 540 for (Block *successor : getSuccessors()) 541 successors.push_back(mapper.lookupOrDefault(successor)); 542 543 // Create the new operation. 544 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 545 successors, getNumRegions()); 546 547 // Remember the mapping of any results. 548 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 549 mapper.map(getResult(i), newOp->getResult(i)); 550 551 return newOp; 552 } 553 554 Operation *Operation::cloneWithoutRegions() { 555 BlockAndValueMapping mapper; 556 return cloneWithoutRegions(mapper); 557 } 558 559 /// Create a deep copy of this operation, remapping any operands that use 560 /// values outside of the operation using the map that is provided (leaving 561 /// them alone if no entry is present). Replaces references to cloned 562 /// sub-operations to the corresponding operation that is copied, and adds 563 /// those mappings to the map. 564 Operation *Operation::clone(BlockAndValueMapping &mapper) { 565 auto *newOp = cloneWithoutRegions(mapper); 566 567 // Clone the regions. 568 for (unsigned i = 0; i != numRegions; ++i) 569 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 570 571 return newOp; 572 } 573 574 Operation *Operation::clone() { 575 BlockAndValueMapping mapper; 576 return clone(mapper); 577 } 578 579 //===----------------------------------------------------------------------===// 580 // OpState trait class. 581 //===----------------------------------------------------------------------===// 582 583 // The fallback for the parser is to reject the custom assembly form. 
584 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 585 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 586 } 587 588 // The fallback for the printer is to print in the generic assembly form. 589 void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); } 590 // The fallback for the printer is to print in the generic assembly form. 591 void OpState::printOpName(Operation *op, OpAsmPrinter &p, 592 StringRef defaultDialect) { 593 StringRef name = op->getName().getStringRef(); 594 if (name.startswith((defaultDialect + ".").str())) 595 name = name.drop_front(defaultDialect.size() + 1); 596 // TODO: remove this special case (and update test/IR/parser.mlir) 597 else if ((defaultDialect.empty() || defaultDialect == "builtin") && 598 name.startswith("std.")) 599 name = name.drop_front(4); 600 p.getStream() << name; 601 } 602 603 /// Emit an error about fatal conditions with this operation, reporting up to 604 /// any diagnostic handlers that may be listening. 605 InFlightDiagnostic OpState::emitError(const Twine &message) { 606 return getOperation()->emitError(message); 607 } 608 609 /// Emit an error with the op name prefixed, like "'dim' op " which is 610 /// convenient for verifiers. 611 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 612 return getOperation()->emitOpError(message); 613 } 614 615 /// Emit a warning about this operation, reporting up to any diagnostic 616 /// handlers that may be listening. 617 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 618 return getOperation()->emitWarning(message); 619 } 620 621 /// Emit a remark about this operation, reporting up to any diagnostic 622 /// handlers that may be listening. 
623 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 624 return getOperation()->emitRemark(message); 625 } 626 627 //===----------------------------------------------------------------------===// 628 // Op Trait implementations 629 //===----------------------------------------------------------------------===// 630 631 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 632 if (op->getNumOperands() == 1) { 633 auto *argumentOp = op->getOperand(0).getDefiningOp(); 634 if (argumentOp && op->getName() == argumentOp->getName()) { 635 // Replace the outer operation output with the inner operation. 636 return op->getOperand(0); 637 } 638 } else if (op->getOperand(0) == op->getOperand(1)) { 639 return op->getOperand(0); 640 } 641 642 return {}; 643 } 644 645 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 646 auto *argumentOp = op->getOperand(0).getDefiningOp(); 647 if (argumentOp && op->getName() == argumentOp->getName()) { 648 // Replace the outer involutions output with inner's input. 
649 return argumentOp->getOperand(0); 650 } 651 652 return {}; 653 } 654 655 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 656 if (op->getNumOperands() != 0) 657 return op->emitOpError() << "requires zero operands"; 658 return success(); 659 } 660 661 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 662 if (op->getNumOperands() != 1) 663 return op->emitOpError() << "requires a single operand"; 664 return success(); 665 } 666 667 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 668 unsigned numOperands) { 669 if (op->getNumOperands() != numOperands) { 670 return op->emitOpError() << "expected " << numOperands 671 << " operands, but found " << op->getNumOperands(); 672 } 673 return success(); 674 } 675 676 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 677 unsigned numOperands) { 678 if (op->getNumOperands() < numOperands) 679 return op->emitOpError() 680 << "expected " << numOperands << " or more operands, but found " 681 << op->getNumOperands(); 682 return success(); 683 } 684 685 /// If this is a vector type, or a tensor type, return the scalar element type 686 /// that it is built around, otherwise return the type unmodified. 687 static Type getTensorOrVectorElementType(Type type) { 688 if (auto vec = type.dyn_cast<VectorType>()) 689 return vec.getElementType(); 690 691 // Look through tensor<vector<...>> to find the underlying element type. 692 if (auto tensor = type.dyn_cast<TensorType>()) 693 return getTensorOrVectorElementType(tensor.getElementType()); 694 return type; 695 } 696 697 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 698 // FIXME: Add back check for no side effects on operation. 699 // Currently adding it would cause the shared library build 700 // to fail since there would be a dependency of IR on SideEffectInterfaces 701 // which is cyclical. 
702 return success(); 703 } 704 705 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 706 // FIXME: Add back check for no side effects on operation. 707 // Currently adding it would cause the shared library build 708 // to fail since there would be a dependency of IR on SideEffectInterfaces 709 // which is cyclical. 710 return success(); 711 } 712 713 LogicalResult 714 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 715 for (auto opType : op->getOperandTypes()) { 716 auto type = getTensorOrVectorElementType(opType); 717 if (!type.isSignlessIntOrIndex()) 718 return op->emitOpError() << "requires an integer or index type"; 719 } 720 return success(); 721 } 722 723 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 724 for (auto opType : op->getOperandTypes()) { 725 auto type = getTensorOrVectorElementType(opType); 726 if (!type.isa<FloatType>()) 727 return op->emitOpError("requires a float type"); 728 } 729 return success(); 730 } 731 732 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 733 // Zero or one operand always have the "same" type. 
734 unsigned nOperands = op->getNumOperands(); 735 if (nOperands < 2) 736 return success(); 737 738 auto type = op->getOperand(0).getType(); 739 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 740 if (opType != type) 741 return op->emitOpError() << "requires all operands to have the same type"; 742 return success(); 743 } 744 745 LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) { 746 if (op->getNumRegions() != 0) 747 return op->emitOpError() << "requires zero regions"; 748 return success(); 749 } 750 751 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 752 if (op->getNumRegions() != 1) 753 return op->emitOpError() << "requires one region"; 754 return success(); 755 } 756 757 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 758 unsigned numRegions) { 759 if (op->getNumRegions() != numRegions) 760 return op->emitOpError() << "expected " << numRegions << " regions"; 761 return success(); 762 } 763 764 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 765 unsigned numRegions) { 766 if (op->getNumRegions() < numRegions) 767 return op->emitOpError() << "expected " << numRegions << " or more regions"; 768 return success(); 769 } 770 771 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { 772 if (op->getNumResults() != 0) 773 return op->emitOpError() << "requires zero results"; 774 return success(); 775 } 776 777 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 778 if (op->getNumResults() != 1) 779 return op->emitOpError() << "requires one result"; 780 return success(); 781 } 782 783 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 784 unsigned numOperands) { 785 if (op->getNumResults() != numOperands) 786 return op->emitOpError() << "expected " << numOperands << " results"; 787 return success(); 788 } 789 790 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 791 unsigned numOperands) { 792 if (op->getNumResults() < numOperands) 793 return op->emitOpError() 
794 << "expected " << numOperands << " or more results"; 795 return success(); 796 } 797 798 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 799 if (failed(verifyAtLeastNOperands(op, 1))) 800 return failure(); 801 802 if (failed(verifyCompatibleShapes(op->getOperandTypes()))) 803 return op->emitOpError() << "requires the same shape for all operands"; 804 805 return success(); 806 } 807 808 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 809 if (failed(verifyAtLeastNOperands(op, 1)) || 810 failed(verifyAtLeastNResults(op, 1))) 811 return failure(); 812 813 SmallVector<Type, 8> types(op->getOperandTypes()); 814 types.append(llvm::to_vector<4>(op->getResultTypes())); 815 816 if (failed(verifyCompatibleShapes(types))) 817 return op->emitOpError() 818 << "requires the same shape for all operands and results"; 819 820 return success(); 821 } 822 823 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 824 if (failed(verifyAtLeastNOperands(op, 1))) 825 return failure(); 826 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 827 828 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 829 if (getElementTypeOrSelf(operand) != elementType) 830 return op->emitOpError("requires the same element type for all operands"); 831 } 832 833 return success(); 834 } 835 836 LogicalResult 837 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 838 if (failed(verifyAtLeastNOperands(op, 1)) || 839 failed(verifyAtLeastNResults(op, 1))) 840 return failure(); 841 842 auto elementType = getElementTypeOrSelf(op->getResult(0)); 843 844 // Verify result element type matches first result's element type. 
845 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 846 if (getElementTypeOrSelf(result) != elementType) 847 return op->emitOpError( 848 "requires the same element type for all operands and results"); 849 } 850 851 // Verify operand's element type matches first result's element type. 852 for (auto operand : op->getOperands()) { 853 if (getElementTypeOrSelf(operand) != elementType) 854 return op->emitOpError( 855 "requires the same element type for all operands and results"); 856 } 857 858 return success(); 859 } 860 861 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 862 if (failed(verifyAtLeastNOperands(op, 1)) || 863 failed(verifyAtLeastNResults(op, 1))) 864 return failure(); 865 866 auto type = op->getResult(0).getType(); 867 auto elementType = getElementTypeOrSelf(type); 868 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 869 if (getElementTypeOrSelf(resultType) != elementType || 870 failed(verifyCompatibleShape(resultType, type))) 871 return op->emitOpError() 872 << "requires the same type for all operands and results"; 873 } 874 for (auto opType : op->getOperandTypes()) { 875 if (getElementTypeOrSelf(opType) != elementType || 876 failed(verifyCompatibleShape(opType, type))) 877 return op->emitOpError() 878 << "requires the same type for all operands and results"; 879 } 880 return success(); 881 } 882 883 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 884 Block *block = op->getBlock(); 885 // Verify that the operation is at the end of the respective parent block. 886 if (!block || &block->back() != op) 887 return op->emitOpError("must be the last operation in the parent block"); 888 return success(); 889 } 890 891 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 892 auto *parent = op->getParentRegion(); 893 894 // Verify that the operands lines up with the BB arguments in the successor. 
895 for (Block *succ : op->getSuccessors()) 896 if (succ->getParent() != parent) 897 return op->emitError("reference to block defined in another region"); 898 return success(); 899 } 900 901 LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { 902 if (op->getNumSuccessors() != 0) { 903 return op->emitOpError("requires 0 successors but found ") 904 << op->getNumSuccessors(); 905 } 906 return success(); 907 } 908 909 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 910 if (op->getNumSuccessors() != 1) { 911 return op->emitOpError("requires 1 successor but found ") 912 << op->getNumSuccessors(); 913 } 914 return verifyTerminatorSuccessors(op); 915 } 916 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 917 unsigned numSuccessors) { 918 if (op->getNumSuccessors() != numSuccessors) { 919 return op->emitOpError("requires ") 920 << numSuccessors << " successors but found " 921 << op->getNumSuccessors(); 922 } 923 return verifyTerminatorSuccessors(op); 924 } 925 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 926 unsigned numSuccessors) { 927 if (op->getNumSuccessors() < numSuccessors) { 928 return op->emitOpError("requires at least ") 929 << numSuccessors << " successors but found " 930 << op->getNumSuccessors(); 931 } 932 return verifyTerminatorSuccessors(op); 933 } 934 935 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 936 for (auto resultType : op->getResultTypes()) { 937 auto elementType = getTensorOrVectorElementType(resultType); 938 bool isBoolType = elementType.isInteger(1); 939 if (!isBoolType) 940 return op->emitOpError() << "requires a bool result type"; 941 } 942 943 return success(); 944 } 945 946 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 947 for (auto resultType : op->getResultTypes()) 948 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 949 return op->emitOpError() << "requires a floating point type"; 950 951 return success(); 
952 } 953 954 LogicalResult 955 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 956 for (auto resultType : op->getResultTypes()) 957 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 958 return op->emitOpError() << "requires an integer or index type"; 959 return success(); 960 } 961 962 LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, 963 StringRef attrName, 964 StringRef valueGroupName, 965 size_t expectedCount) { 966 auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); 967 if (!sizeAttr) 968 return op->emitOpError("requires 1D i32 elements attribute '") 969 << attrName << "'"; 970 971 auto sizeAttrType = sizeAttr.getType(); 972 if (sizeAttrType.getRank() != 1 || 973 !sizeAttrType.getElementType().isInteger(32)) 974 return op->emitOpError("requires 1D i32 elements attribute '") 975 << attrName << "'"; 976 977 if (llvm::any_of(sizeAttr.getValues<APInt>(), [](const APInt &element) { 978 return !element.isNonNegative(); 979 })) 980 return op->emitOpError("'") 981 << attrName << "' attribute cannot have negative elements"; 982 983 size_t totalCount = std::accumulate( 984 sizeAttr.begin(), sizeAttr.end(), 0, 985 [](unsigned all, APInt one) { return all + one.getZExtValue(); }); 986 987 if (totalCount != expectedCount) 988 return op->emitOpError() 989 << valueGroupName << " count (" << expectedCount 990 << ") does not match with the total size (" << totalCount 991 << ") specified in attribute '" << attrName << "'"; 992 return success(); 993 } 994 995 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 996 StringRef attrName) { 997 return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); 998 } 999 1000 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1001 StringRef attrName) { 1002 return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); 1003 } 1004 1005 LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1006 for (Region 
®ion : op->getRegions()) { 1007 if (region.empty()) 1008 continue; 1009 1010 if (region.getNumArguments() != 0) { 1011 if (op->getNumRegions() > 1) 1012 return op->emitOpError("region #") 1013 << region.getRegionNumber() << " should have no arguments"; 1014 else 1015 return op->emitOpError("region should have no arguments"); 1016 } 1017 } 1018 return success(); 1019 } 1020 1021 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1022 auto isMappableType = [](Type type) { 1023 return type.isa<VectorType, TensorType>(); 1024 }; 1025 auto resultMappableTypes = llvm::to_vector<1>( 1026 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1027 auto operandMappableTypes = llvm::to_vector<2>( 1028 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1029 1030 // If the op only has scalar operand/result types, then we have nothing to 1031 // check. 1032 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1033 return success(); 1034 1035 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1036 return op->emitOpError("if a result is non-scalar, then at least one " 1037 "operand must be non-scalar"); 1038 1039 assert(!operandMappableTypes.empty()); 1040 1041 if (resultMappableTypes.empty()) 1042 return op->emitOpError("if an operand is non-scalar, then there must be at " 1043 "least one non-scalar result"); 1044 1045 if (resultMappableTypes.size() != op->getNumResults()) 1046 return op->emitOpError( 1047 "if an operand is non-scalar, then all results must be non-scalar"); 1048 1049 SmallVector<Type, 4> types = llvm::to_vector<2>( 1050 llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); 1051 TypeID expectedBaseTy = types.front().getTypeID(); 1052 if (!llvm::all_of(types, 1053 [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || 1054 failed(verifyCompatibleShapes(types))) { 1055 return op->emitOpError() << "all non-scalar operands/results must have the " 1056 "same shape and base type"; 1057 } 1058 
1059 return success(); 1060 } 1061 1062 /// Check for any values used by operations regions attached to the 1063 /// specified "IsIsolatedFromAbove" operation defined outside of it. 1064 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { 1065 assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1066 "Intended to check IsolatedFromAbove ops"); 1067 1068 // List of regions to analyze. Each region is processed independently, with 1069 // respect to the common `limit` region, so we can look at them in any order. 1070 // Therefore, use a simple vector and push/pop back the current region. 1071 SmallVector<Region *, 8> pendingRegions; 1072 for (auto ®ion : isolatedOp->getRegions()) { 1073 pendingRegions.push_back(®ion); 1074 1075 // Traverse all operations in the region. 1076 while (!pendingRegions.empty()) { 1077 for (Operation &op : pendingRegions.pop_back_val()->getOps()) { 1078 for (Value operand : op.getOperands()) { 1079 // operand should be non-null here if the IR is well-formed. But 1080 // we don't assert here as this function is called from the verifier 1081 // and so could be called on invalid IR. 1082 if (!operand) 1083 return op.emitOpError("operation's operand is null"); 1084 1085 // Check that any value that is used by an operation is defined in the 1086 // same region as either an operation result. 1087 auto *operandRegion = operand.getParentRegion(); 1088 if (!region.isAncestor(operandRegion)) { 1089 return op.emitOpError("using value defined outside the region") 1090 .attachNote(isolatedOp->getLoc()) 1091 << "required by region isolation constraints"; 1092 } 1093 } 1094 1095 // Schedule any regions in the operation for further checking. Don't 1096 // recurse into other IsolatedFromAbove ops, because they will check 1097 // themselves. 
1098 if (op.getNumRegions() && 1099 !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1100 for (Region &subRegion : op.getRegions()) 1101 pendingRegions.push_back(&subRegion); 1102 } 1103 } 1104 } 1105 } 1106 1107 return success(); 1108 } 1109 1110 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1111 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1112 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1113 } 1114 1115 //===----------------------------------------------------------------------===// 1116 // BinaryOp implementation 1117 //===----------------------------------------------------------------------===// 1118 1119 // These functions are out-of-line implementations of the methods in BinaryOp, 1120 // which avoids them being template instantiated/duplicated. 1121 1122 void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, 1123 Value rhs) { 1124 assert(lhs.getType() == rhs.getType()); 1125 result.addOperands({lhs, rhs}); 1126 result.types.push_back(lhs.getType()); 1127 } 1128 1129 ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser, 1130 OperationState &result) { 1131 SmallVector<OpAsmParser::OperandType, 2> ops; 1132 Type type; 1133 // If the operand list is in-between parentheses, then we have a generic form. 1134 // (see the fallback in `printOneResultOp`). 
1135 llvm::SMLoc loc = parser.getCurrentLocation(); 1136 if (!parser.parseOptionalLParen()) { 1137 if (parser.parseOperandList(ops) || parser.parseRParen() || 1138 parser.parseOptionalAttrDict(result.attributes) || 1139 parser.parseColon() || parser.parseType(type)) 1140 return failure(); 1141 auto fnType = type.dyn_cast<FunctionType>(); 1142 if (!fnType) { 1143 parser.emitError(loc, "expected function type"); 1144 return failure(); 1145 } 1146 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) 1147 return failure(); 1148 result.addTypes(fnType.getResults()); 1149 return success(); 1150 } 1151 return failure(parser.parseOperandList(ops) || 1152 parser.parseOptionalAttrDict(result.attributes) || 1153 parser.parseColonType(type) || 1154 parser.resolveOperands(ops, type, result.operands) || 1155 parser.addTypeToList(type, result.types)); 1156 } 1157 1158 void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { 1159 assert(op->getNumResults() == 1 && "op should have one result"); 1160 1161 // If not all the operand and result types are the same, just use the 1162 // generic assembly form to avoid omitting information in printing. 1163 auto resultType = op->getResult(0).getType(); 1164 if (llvm::any_of(op->getOperandTypes(), 1165 [&](Type type) { return type != resultType; })) { 1166 p.printGenericOp(op, /*printOpName=*/false); 1167 return; 1168 } 1169 1170 p << ' '; 1171 p.printOperands(op->getOperands()); 1172 p.printOptionalAttrDict(op->getAttrs()); 1173 // Now we can output only one type for all operands and the result. 1174 p << " : " << resultType; 1175 } 1176 1177 //===----------------------------------------------------------------------===// 1178 // CastOp implementation 1179 //===----------------------------------------------------------------------===// 1180 1181 /// Attempt to fold the given cast operation. 
1182 LogicalResult 1183 impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, 1184 SmallVectorImpl<OpFoldResult> &foldResults) { 1185 OperandRange operands = op->getOperands(); 1186 if (operands.empty()) 1187 return failure(); 1188 ResultRange results = op->getResults(); 1189 1190 // Check for the case where the input and output types match 1-1. 1191 if (operands.getTypes() == results.getTypes()) { 1192 foldResults.append(operands.begin(), operands.end()); 1193 return success(); 1194 } 1195 1196 return failure(); 1197 } 1198 1199 /// Attempt to verify the given cast operation. 1200 LogicalResult impl::verifyCastInterfaceOp( 1201 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { 1202 auto resultTypes = op->getResultTypes(); 1203 if (llvm::empty(resultTypes)) 1204 return op->emitOpError() 1205 << "expected at least one result for cast operation"; 1206 1207 auto operandTypes = op->getOperandTypes(); 1208 if (!areCastCompatible(operandTypes, resultTypes)) { 1209 InFlightDiagnostic diag = op->emitOpError("operand type"); 1210 if (llvm::empty(operandTypes)) 1211 diag << "s []"; 1212 else if (llvm::size(operandTypes) == 1) 1213 diag << " " << *operandTypes.begin(); 1214 else 1215 diag << "s " << operandTypes; 1216 return diag << " and result type" << (resultTypes.size() == 1 ? 
" " : "s ") 1217 << resultTypes << " are cast incompatible"; 1218 } 1219 1220 return success(); 1221 } 1222 1223 void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source, 1224 Type destType) { 1225 result.addOperands(source); 1226 result.addTypes(destType); 1227 } 1228 1229 ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) { 1230 OpAsmParser::OperandType srcInfo; 1231 Type srcType, dstType; 1232 return failure(parser.parseOperand(srcInfo) || 1233 parser.parseOptionalAttrDict(result.attributes) || 1234 parser.parseColonType(srcType) || 1235 parser.resolveOperand(srcInfo, srcType, result.operands) || 1236 parser.parseKeywordType("to", dstType) || 1237 parser.addTypeToList(dstType, result.types)); 1238 } 1239 1240 void impl::printCastOp(Operation *op, OpAsmPrinter &p) { 1241 p << ' ' << op->getOperand(0); 1242 p.printOptionalAttrDict(op->getAttrs()); 1243 p << " : " << op->getOperand(0).getType() << " to " 1244 << op->getResult(0).getType(); 1245 } 1246 1247 Value impl::foldCastOp(Operation *op) { 1248 // Identity cast 1249 if (op->getOperand(0).getType() == op->getResult(0).getType()) 1250 return op->getOperand(0); 1251 return nullptr; 1252 } 1253 1254 LogicalResult 1255 impl::verifyCastOp(Operation *op, 1256 function_ref<bool(Type, Type)> areCastCompatible) { 1257 auto opType = op->getOperand(0).getType(); 1258 auto resType = op->getResult(0).getType(); 1259 if (!areCastCompatible(opType, resType)) 1260 return op->emitError("operand type ") 1261 << opType << " and result type " << resType 1262 << " are cast incompatible"; 1263 1264 return success(); 1265 } 1266 1267 //===----------------------------------------------------------------------===// 1268 // Misc. utils 1269 //===----------------------------------------------------------------------===// 1270 1271 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1272 /// region's only block if it does not have a terminator already. 
If the region 1273 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1274 /// terminator operation to insert. 1275 void impl::ensureRegionTerminator( 1276 Region ®ion, OpBuilder &builder, Location loc, 1277 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1278 OpBuilder::InsertionGuard guard(builder); 1279 if (region.empty()) 1280 builder.createBlock(®ion); 1281 1282 Block &block = region.back(); 1283 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1284 return; 1285 1286 builder.setInsertionPointToEnd(&block); 1287 builder.insert(buildTerminatorOp(builder, loc)); 1288 } 1289 1290 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1291 /// function. 1292 void impl::ensureRegionTerminator( 1293 Region ®ion, Builder &builder, Location loc, 1294 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1295 OpBuilder opBuilder(builder.getContext()); 1296 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1297 } 1298