1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Interfaces/FoldInterfaces.h" 17 #include "llvm/ADT/StringExtras.h" 18 #include <numeric> 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // OperationName 24 //===----------------------------------------------------------------------===// 25 26 /// Form the OperationName for an op with the specified string. This either is 27 /// a reference to an AbstractOperation if one is known, or a uniqued Identifier 28 /// if not. 29 OperationName::OperationName(StringRef name, MLIRContext *context) { 30 if (auto *op = AbstractOperation::lookup(name, context)) 31 representation = op; 32 else 33 representation = Identifier::get(name, context); 34 } 35 36 /// Return the name of the dialect this operation is registered to. 37 StringRef OperationName::getDialectNamespace() const { 38 if (Dialect *dialect = getDialect()) 39 return dialect->getNamespace(); 40 return getStringRef().split('.').first; 41 } 42 43 /// Return the operation name with dialect name stripped, if it has one. 44 StringRef OperationName::stripDialect() const { 45 return getStringRef().split('.').second; 46 } 47 48 /// Return the name of this operation. This always succeeds. 
49 StringRef OperationName::getStringRef() const { 50 return getIdentifier().strref(); 51 } 52 53 /// Return the name of this operation as an identifier. This always succeeds. 54 Identifier OperationName::getIdentifier() const { 55 if (auto *op = representation.dyn_cast<const AbstractOperation *>()) 56 return op->name; 57 return representation.get<Identifier>(); 58 } 59 60 OperationName OperationName::getFromOpaquePointer(const void *pointer) { 61 return OperationName( 62 RepresentationUnion::getFromOpaqueValue(const_cast<void *>(pointer))); 63 } 64 65 //===----------------------------------------------------------------------===// 66 // Operation 67 //===----------------------------------------------------------------------===// 68 69 /// Create a new Operation with the specific fields. 70 Operation *Operation::create(Location location, OperationName name, 71 TypeRange resultTypes, ValueRange operands, 72 ArrayRef<NamedAttribute> attributes, 73 BlockRange successors, unsigned numRegions) { 74 return create(location, name, resultTypes, operands, 75 DictionaryAttr::get(location.getContext(), attributes), 76 successors, numRegions); 77 } 78 79 /// Create a new Operation from operation state. 80 Operation *Operation::create(const OperationState &state) { 81 return create(state.location, state.name, state.types, state.operands, 82 state.attributes.getDictionary(state.getContext()), 83 state.successors, state.regions); 84 } 85 86 /// Create a new Operation with the specific fields. 
87 Operation *Operation::create(Location location, OperationName name, 88 TypeRange resultTypes, ValueRange operands, 89 DictionaryAttr attributes, BlockRange successors, 90 RegionRange regions) { 91 unsigned numRegions = regions.size(); 92 Operation *op = create(location, name, resultTypes, operands, attributes, 93 successors, numRegions); 94 for (unsigned i = 0; i < numRegions; ++i) 95 if (regions[i]) 96 op->getRegion(i).takeBody(*regions[i]); 97 return op; 98 } 99 100 /// Overload of create that takes an existing DictionaryAttr to avoid 101 /// unnecessarily uniquing a list of attributes. 102 Operation *Operation::create(Location location, OperationName name, 103 TypeRange resultTypes, ValueRange operands, 104 DictionaryAttr attributes, BlockRange successors, 105 unsigned numRegions) { 106 assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && 107 "unexpected null result type"); 108 109 // We only need to allocate additional memory for a subset of results. 110 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 111 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 112 unsigned numSuccessors = successors.size(); 113 unsigned numOperands = operands.size(); 114 unsigned numResults = resultTypes.size(); 115 116 // If the operation is known to have no operands, don't allocate an operand 117 // storage. 118 bool needsOperandStorage = true; 119 if (operands.empty()) { 120 if (const AbstractOperation *abstractOp = name.getAbstractOperation()) 121 needsOperandStorage = !abstractOp->hasTrait<OpTrait::ZeroOperands>(); 122 } 123 124 // Compute the byte size for the operation and the operand storage. This takes 125 // into account the size of the operation, its trailing objects, and its 126 // prefixed objects. 127 size_t byteSize = 128 totalSizeToAlloc<BlockOperand, Region, detail::OperandStorage>( 129 numSuccessors, numRegions, needsOperandStorage ? 
1 : 0) + 130 detail::OperandStorage::additionalAllocSize(numOperands); 131 size_t prefixByteSize = llvm::alignTo( 132 Operation::prefixAllocSize(numTrailingResults, numInlineResults), 133 alignof(Operation)); 134 char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); 135 void *rawMem = mallocMem + prefixByteSize; 136 137 // Create the new Operation. 138 Operation *op = 139 ::new (rawMem) Operation(location, name, numResults, numSuccessors, 140 numRegions, attributes, needsOperandStorage); 141 142 assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && 143 "unexpected successors in a non-terminator operation"); 144 145 // Initialize the results. 146 auto resultTypeIt = resultTypes.begin(); 147 for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) 148 new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); 149 for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { 150 new (op->getOutOfLineOpResult(i)) 151 detail::OutOfLineOpResult(*resultTypeIt, i); 152 } 153 154 // Initialize the regions. 155 for (unsigned i = 0; i != numRegions; ++i) 156 new (&op->getRegion(i)) Region(op); 157 158 // Initialize the operands. 159 if (needsOperandStorage) 160 new (&op->getOperandStorage()) detail::OperandStorage(op, operands); 161 162 // Initialize the successors. 
163 auto blockOperands = op->getBlockOperands(); 164 for (unsigned i = 0; i != numSuccessors; ++i) 165 new (&blockOperands[i]) BlockOperand(op, successors[i]); 166 167 return op; 168 } 169 170 Operation::Operation(Location location, OperationName name, unsigned numResults, 171 unsigned numSuccessors, unsigned numRegions, 172 DictionaryAttr attributes, bool hasOperandStorage) 173 : location(location), numResults(numResults), numSuccs(numSuccessors), 174 numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), 175 attrs(attributes) { 176 assert(attributes && "unexpected null attribute dictionary"); 177 #ifndef NDEBUG 178 if (!getDialect() && !getContext()->allowsUnregisteredDialects()) 179 llvm::report_fatal_error( 180 name.getStringRef() + 181 " created with unregistered dialect. If this is intended, please call " 182 "allowUnregisteredDialects() on the MLIRContext, or use " 183 "-allow-unregistered-dialect with the MLIR opt tool used"); 184 #endif 185 } 186 187 // Operations are deleted through the destroy() member because they are 188 // allocated via malloc. 189 Operation::~Operation() { 190 assert(block == nullptr && "operation destroyed but still in a block"); 191 #ifndef NDEBUG 192 if (!use_empty()) { 193 { 194 InFlightDiagnostic diag = 195 emitOpError("operation destroyed but still has uses"); 196 for (Operation *user : getUsers()) 197 diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; 198 } 199 llvm::report_fatal_error("operation destroyed but still has uses"); 200 } 201 #endif 202 // Explicitly run the destructors for the operands. 203 if (hasOperandStorage) 204 getOperandStorage().~OperandStorage(); 205 206 // Explicitly run the destructors for the successors. 207 for (auto &successor : getBlockOperands()) 208 successor.~BlockOperand(); 209 210 // Explicitly destroy the regions. 211 for (auto ®ion : getRegions()) 212 region.~Region(); 213 } 214 215 /// Destroy this operation or one of its subclasses. 
216 void Operation::destroy() { 217 // Operations may have additional prefixed allocation, which needs to be 218 // accounted for here when computing the address to free. 219 char *rawMem = reinterpret_cast<char *>(this) - 220 llvm::alignTo(prefixAllocSize(), alignof(Operation)); 221 this->~Operation(); 222 free(rawMem); 223 } 224 225 /// Return true if this operation is a proper ancestor of the `other` 226 /// operation. 227 bool Operation::isProperAncestor(Operation *other) { 228 while ((other = other->getParentOp())) 229 if (this == other) 230 return true; 231 return false; 232 } 233 234 /// Replace any uses of 'from' with 'to' within this operation. 235 void Operation::replaceUsesOfWith(Value from, Value to) { 236 if (from == to) 237 return; 238 for (auto &operand : getOpOperands()) 239 if (operand.get() == from) 240 operand.set(to); 241 } 242 243 /// Replace the current operands of this operation with the ones provided in 244 /// 'operands'. 245 void Operation::setOperands(ValueRange operands) { 246 if (LLVM_LIKELY(hasOperandStorage)) 247 return getOperandStorage().setOperands(this, operands); 248 assert(operands.empty() && "setting operands without an operand storage"); 249 } 250 251 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 252 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 253 /// than the range pointed to by 'start'+'length'. 254 void Operation::setOperands(unsigned start, unsigned length, 255 ValueRange operands) { 256 assert((start + length) <= getNumOperands() && 257 "invalid operand range specified"); 258 if (LLVM_LIKELY(hasOperandStorage)) 259 return getOperandStorage().setOperands(this, start, length, operands); 260 assert(operands.empty() && "setting operands without an operand storage"); 261 } 262 263 /// Insert the given operands into the operand list at the given 'index'. 
264 void Operation::insertOperands(unsigned index, ValueRange operands) { 265 if (LLVM_LIKELY(hasOperandStorage)) 266 return setOperands(index, /*length=*/0, operands); 267 assert(operands.empty() && "inserting operands without an operand storage"); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // Diagnostics 272 //===----------------------------------------------------------------------===// 273 274 /// Emit an error about fatal conditions with this operation, reporting up to 275 /// any diagnostic handlers that may be listening. 276 InFlightDiagnostic Operation::emitError(const Twine &message) { 277 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 278 if (getContext()->shouldPrintOpOnDiagnostic()) { 279 // Print out the operation explicitly here so that we can print the generic 280 // form. 281 // TODO: It would be nice if we could instead provide the 282 // specific printing flags when adding the operation as an argument to the 283 // diagnostic. 284 std::string printedOp; 285 { 286 llvm::raw_string_ostream os(printedOp); 287 print(os, OpPrintingFlags().printGenericOpForm().useLocalScope()); 288 } 289 diag.attachNote(getLoc()) << "see current operation: " << printedOp; 290 } 291 return diag; 292 } 293 294 /// Emit a warning about this operation, reporting up to any diagnostic 295 /// handlers that may be listening. 296 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 297 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 298 if (getContext()->shouldPrintOpOnDiagnostic()) 299 diag.attachNote(getLoc()) << "see current operation: " << *this; 300 return diag; 301 } 302 303 /// Emit a remark about this operation, reporting up to any diagnostic 304 /// handlers that may be listening. 
InFlightDiagnostic Operation::emitRemark(const Twine &message) {
  InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message);
  // Attach the operation itself as a note when the context requests it.
  if (getContext()->shouldPrintOpOnDiagnostic())
    diag.attachNote(getLoc()) << "see current operation: " << *this;
  return diag;
}

//===----------------------------------------------------------------------===//
// Operation Ordering
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the static constexpr members. Before C++17 these
// are required when the members are ODR-used.
constexpr unsigned Operation::kInvalidOrderIdx;
constexpr unsigned Operation::kOrderStride;

/// Given an operation 'other' that is within the same parent block, return
/// whether the current operation is before 'other' in the operation list
/// of the parent block.
/// Note: This function has an average complexity of O(1), but worst case may
/// take O(N) where N is the number of operations within the parent block.
bool Operation::isBeforeInBlock(Operation *other) {
  assert(block && "Operations without parent blocks have no order.");
  assert(other && other->block == block &&
         "Expected other operation to have the same parent block.");
  // If the order of the block is already invalid, directly recompute the
  // order of every operation in the parent block.
  if (!block->isOpOrderValid()) {
    block->recomputeOpOrder();
  } else {
    // Update the order of either operation if necessary.
    updateOrderIfNecessary();
    other->updateOrderIfNecessary();
  }

  return orderIndex < other->orderIndex;
}

/// Update the order index of this operation if necessary, potentially
/// recomputing the order of the parent block. A new index is derived from the
/// neighbours' indices when they are valid and leave room; otherwise the whole
/// block is renumbered via recomputeOpOrder().
void Operation::updateOrderIfNecessary() {
  assert(block && "expected valid parent");

  // If the order is valid for this operation there is nothing to do.
  if (hasValidOrder())
    return;
  Operation *blockFront = &block->front();
  Operation *blockBack = &block->back();

  // This method is expected to only be invoked on blocks with more than one
  // operation.
  assert(blockFront != blockBack && "expected more than one operation");

  // If the operation is at the end of the block.
  if (this == blockBack) {
    Operation *prevNode = getPrevNode();
    if (!prevNode->hasValidOrder())
      return block->recomputeOpOrder();

    // Add the stride to the previous operation.
    // NOTE(review): this unsigned addition is not checked for wrap-around.
    orderIndex = prevNode->orderIndex + kOrderStride;
    return;
  }

  // If this is the first operation try to use the next operation to compute the
  // ordering.
  if (this == blockFront) {
    Operation *nextNode = getNextNode();
    if (!nextNode->hasValidOrder())
      return block->recomputeOpOrder();
    // There is no order to give this operation.
    if (nextNode->orderIndex == 0)
      return block->recomputeOpOrder();

    // If we can't use the stride, just take the middle value left. This is safe
    // because we know there is at least one valid index to assign to.
    if (nextNode->orderIndex <= kOrderStride)
      orderIndex = (nextNode->orderIndex / 2);
    else
      orderIndex = kOrderStride;
    return;
  }

  // Otherwise, this operation is between two others. Place this operation in
  // the middle of the previous and next if possible.
  Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
  if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
    return block->recomputeOpOrder();
  unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;

  // Check to see if there is a valid order between the two; adjacent indices
  // leave no free slot, so the block must be renumbered.
  if (prevOrder + 1 == nextOrder)
    return block->recomputeOpOrder();
  orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}

//===----------------------------------------------------------------------===//
// ilist_traits for Operation
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the node/value pointer conversions for the
// Operation ilist; each one forwards to llvm's NodeAccess helpers.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(N);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer N)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(N);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(N);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *N)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(N);
}

/// Deleting a node from the list destroys the operation through destroy(),
/// which frees the malloc'd allocation.
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
  op->destroy();
}

/// Recover the Block that owns this operation list. The offset of the list
/// member inside Block is computed, then subtracted from the list's address.
/// NOTE(review): the offset is formed through a null Block pointer, which is
/// not strictly portable C++; it relies on the usual compiler behavior.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset);
}

/// This is a trait method invoked when an operation is added to a block. We
/// keep the block pointer up to date.
441 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 442 assert(!op->getBlock() && "already in an operation block!"); 443 op->block = getContainingBlock(); 444 445 // Invalidate the order on the operation. 446 op->orderIndex = Operation::kInvalidOrderIdx; 447 } 448 449 /// This is a trait method invoked when an operation is removed from a block. 450 /// We keep the block pointer up to date. 451 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 452 assert(op->block && "not already in an operation block!"); 453 op->block = nullptr; 454 } 455 456 /// This is a trait method invoked when an operation is moved from one block 457 /// to another. We keep the block pointer up to date. 458 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 459 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 460 Block *curParent = getContainingBlock(); 461 462 // Invalidate the ordering of the parent block. 463 curParent->invalidateOpOrder(); 464 465 // If we are transferring operations within the same block, the block 466 // pointer doesn't need to be updated. 467 if (curParent == otherList.getContainingBlock()) 468 return; 469 470 // Update the 'block' member of each operation. 471 for (; first != last; ++first) 472 first->block = curParent; 473 } 474 475 /// Remove this operation (and its descendants) from its Block and delete 476 /// all of them. 477 void Operation::erase() { 478 if (auto *parent = getBlock()) 479 parent->getOperations().erase(this); 480 else 481 destroy(); 482 } 483 484 /// Remove the operation from its parent block, but don't delete it. 485 void Operation::remove() { 486 if (Block *parent = getBlock()) 487 parent->getOperations().remove(this); 488 } 489 490 /// Unlink this operation from its current block and insert it right before 491 /// `existingOp` which may be in the same or another block in the same 492 /// function. 
493 void Operation::moveBefore(Operation *existingOp) { 494 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 495 } 496 497 /// Unlink this operation from its current basic block and insert it right 498 /// before `iterator` in the specified basic block. 499 void Operation::moveBefore(Block *block, 500 llvm::iplist<Operation>::iterator iterator) { 501 block->getOperations().splice(iterator, getBlock()->getOperations(), 502 getIterator()); 503 } 504 505 /// Unlink this operation from its current block and insert it right after 506 /// `existingOp` which may be in the same or another block in the same function. 507 void Operation::moveAfter(Operation *existingOp) { 508 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 509 } 510 511 /// Unlink this operation from its current block and insert it right after 512 /// `iterator` in the specified block. 513 void Operation::moveAfter(Block *block, 514 llvm::iplist<Operation>::iterator iterator) { 515 assert(iterator != block->end() && "cannot move after end of block"); 516 moveBefore(&*std::next(iterator)); 517 } 518 519 /// This drops all operand uses from this operation, which is an essential 520 /// step in breaking cyclic dependences between references when they are to 521 /// be deleted. 522 void Operation::dropAllReferences() { 523 for (auto &op : getOpOperands()) 524 op.drop(); 525 526 for (auto ®ion : getRegions()) 527 region.dropAllReferences(); 528 529 for (auto &dest : getBlockOperands()) 530 dest.drop(); 531 } 532 533 /// This drops all uses of any values defined by this operation or its nested 534 /// regions, wherever they are located. 
535 void Operation::dropAllDefinedValueUses() { 536 dropAllUses(); 537 538 for (auto ®ion : getRegions()) 539 for (auto &block : region) 540 block.dropAllDefinedValueUses(); 541 } 542 543 void Operation::setSuccessor(Block *block, unsigned index) { 544 assert(index < getNumSuccessors()); 545 getBlockOperands()[index].set(block); 546 } 547 548 /// Attempt to fold this operation using the Op's registered foldHook. 549 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 550 SmallVectorImpl<OpFoldResult> &results) { 551 // If we have a registered operation definition matching this one, use it to 552 // try to constant fold the operation. 553 auto *abstractOp = getAbstractOperation(); 554 if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results))) 555 return success(); 556 557 // Otherwise, fall back on the dialect hook to handle it. 558 Dialect *dialect = getDialect(); 559 if (!dialect) 560 return failure(); 561 562 auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>(); 563 if (!interface) 564 return failure(); 565 566 return interface->fold(this, operands, results); 567 } 568 569 /// Emit an error with the op name prefixed, like "'dim' op " which is 570 /// convenient for verifiers. 571 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 572 return emitError() << "'" << getName() << "' op " << message; 573 } 574 575 //===----------------------------------------------------------------------===// 576 // Operation Cloning 577 //===----------------------------------------------------------------------===// 578 579 /// Create a deep copy of this operation but keep the operation regions empty. 580 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 581 /// to contain the results. 582 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 583 SmallVector<Value, 8> operands; 584 SmallVector<Block *, 2> successors; 585 586 // Remap the operands. 
587 operands.reserve(getNumOperands()); 588 for (auto opValue : getOperands()) 589 operands.push_back(mapper.lookupOrDefault(opValue)); 590 591 // Remap the successors. 592 successors.reserve(getNumSuccessors()); 593 for (Block *successor : getSuccessors()) 594 successors.push_back(mapper.lookupOrDefault(successor)); 595 596 // Create the new operation. 597 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 598 successors, getNumRegions()); 599 600 // Remember the mapping of any results. 601 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 602 mapper.map(getResult(i), newOp->getResult(i)); 603 604 return newOp; 605 } 606 607 Operation *Operation::cloneWithoutRegions() { 608 BlockAndValueMapping mapper; 609 return cloneWithoutRegions(mapper); 610 } 611 612 /// Create a deep copy of this operation, remapping any operands that use 613 /// values outside of the operation using the map that is provided (leaving 614 /// them alone if no entry is present). Replaces references to cloned 615 /// sub-operations to the corresponding operation that is copied, and adds 616 /// those mappings to the map. 617 Operation *Operation::clone(BlockAndValueMapping &mapper) { 618 auto *newOp = cloneWithoutRegions(mapper); 619 620 // Clone the regions. 621 for (unsigned i = 0; i != numRegions; ++i) 622 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 623 624 return newOp; 625 } 626 627 Operation *Operation::clone() { 628 BlockAndValueMapping mapper; 629 return clone(mapper); 630 } 631 632 //===----------------------------------------------------------------------===// 633 // OpState trait class. 634 //===----------------------------------------------------------------------===// 635 636 // The fallback for the parser is to reject the custom assembly form. 
637 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 638 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 639 } 640 641 // The fallback for the printer is to print in the generic assembly form. 642 void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); } 643 // The fallback for the printer is to print in the generic assembly form. 644 void OpState::printOpName(Operation *op, OpAsmPrinter &p, 645 StringRef defaultDialect) { 646 StringRef name = op->getName().getStringRef(); 647 if (name.startswith((defaultDialect + ".").str())) 648 name = name.drop_front(defaultDialect.size() + 1); 649 // TODO: remove this special case. 650 else if (name.startswith("std.")) 651 name = name.drop_front(4); 652 p.getStream() << name; 653 } 654 655 /// Emit an error about fatal conditions with this operation, reporting up to 656 /// any diagnostic handlers that may be listening. 657 InFlightDiagnostic OpState::emitError(const Twine &message) { 658 return getOperation()->emitError(message); 659 } 660 661 /// Emit an error with the op name prefixed, like "'dim' op " which is 662 /// convenient for verifiers. 663 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 664 return getOperation()->emitOpError(message); 665 } 666 667 /// Emit a warning about this operation, reporting up to any diagnostic 668 /// handlers that may be listening. 669 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 670 return getOperation()->emitWarning(message); 671 } 672 673 /// Emit a remark about this operation, reporting up to any diagnostic 674 /// handlers that may be listening. 
675 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 676 return getOperation()->emitRemark(message); 677 } 678 679 //===----------------------------------------------------------------------===// 680 // Op Trait implementations 681 //===----------------------------------------------------------------------===// 682 683 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 684 auto *argumentOp = op->getOperand(0).getDefiningOp(); 685 if (argumentOp && op->getName() == argumentOp->getName()) { 686 // Replace the outer operation output with the inner operation. 687 return op->getOperand(0); 688 } 689 690 return {}; 691 } 692 693 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 694 auto *argumentOp = op->getOperand(0).getDefiningOp(); 695 if (argumentOp && op->getName() == argumentOp->getName()) { 696 // Replace the outer involutions output with inner's input. 697 return argumentOp->getOperand(0); 698 } 699 700 return {}; 701 } 702 703 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 704 if (op->getNumOperands() != 0) 705 return op->emitOpError() << "requires zero operands"; 706 return success(); 707 } 708 709 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 710 if (op->getNumOperands() != 1) 711 return op->emitOpError() << "requires a single operand"; 712 return success(); 713 } 714 715 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 716 unsigned numOperands) { 717 if (op->getNumOperands() != numOperands) { 718 return op->emitOpError() << "expected " << numOperands 719 << " operands, but found " << op->getNumOperands(); 720 } 721 return success(); 722 } 723 724 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 725 unsigned numOperands) { 726 if (op->getNumOperands() < numOperands) 727 return op->emitOpError() 728 << "expected " << numOperands << " or more operands, but found " 729 << op->getNumOperands(); 730 return success(); 731 } 732 733 /// If this is a vector type, or a 
tensor type, return the scalar element type 734 /// that it is built around, otherwise return the type unmodified. 735 static Type getTensorOrVectorElementType(Type type) { 736 if (auto vec = type.dyn_cast<VectorType>()) 737 return vec.getElementType(); 738 739 // Look through tensor<vector<...>> to find the underlying element type. 740 if (auto tensor = type.dyn_cast<TensorType>()) 741 return getTensorOrVectorElementType(tensor.getElementType()); 742 return type; 743 } 744 745 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 746 // FIXME: Add back check for no side effects on operation. 747 // Currently adding it would cause the shared library build 748 // to fail since there would be a dependency of IR on SideEffectInterfaces 749 // which is cyclical. 750 return success(); 751 } 752 753 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 754 // FIXME: Add back check for no side effects on operation. 755 // Currently adding it would cause the shared library build 756 // to fail since there would be a dependency of IR on SideEffectInterfaces 757 // which is cyclical. 758 return success(); 759 } 760 761 LogicalResult 762 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 763 for (auto opType : op->getOperandTypes()) { 764 auto type = getTensorOrVectorElementType(opType); 765 if (!type.isSignlessIntOrIndex()) 766 return op->emitOpError() << "requires an integer or index type"; 767 } 768 return success(); 769 } 770 771 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 772 for (auto opType : op->getOperandTypes()) { 773 auto type = getTensorOrVectorElementType(opType); 774 if (!type.isa<FloatType>()) 775 return op->emitOpError("requires a float type"); 776 } 777 return success(); 778 } 779 780 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 781 // Zero or one operand always have the "same" type. 
782 unsigned nOperands = op->getNumOperands(); 783 if (nOperands < 2) 784 return success(); 785 786 auto type = op->getOperand(0).getType(); 787 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 788 if (opType != type) 789 return op->emitOpError() << "requires all operands to have the same type"; 790 return success(); 791 } 792 793 LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) { 794 if (op->getNumRegions() != 0) 795 return op->emitOpError() << "requires zero regions"; 796 return success(); 797 } 798 799 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 800 if (op->getNumRegions() != 1) 801 return op->emitOpError() << "requires one region"; 802 return success(); 803 } 804 805 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 806 unsigned numRegions) { 807 if (op->getNumRegions() != numRegions) 808 return op->emitOpError() << "expected " << numRegions << " regions"; 809 return success(); 810 } 811 812 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 813 unsigned numRegions) { 814 if (op->getNumRegions() < numRegions) 815 return op->emitOpError() << "expected " << numRegions << " or more regions"; 816 return success(); 817 } 818 819 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { 820 if (op->getNumResults() != 0) 821 return op->emitOpError() << "requires zero results"; 822 return success(); 823 } 824 825 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 826 if (op->getNumResults() != 1) 827 return op->emitOpError() << "requires one result"; 828 return success(); 829 } 830 831 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 832 unsigned numOperands) { 833 if (op->getNumResults() != numOperands) 834 return op->emitOpError() << "expected " << numOperands << " results"; 835 return success(); 836 } 837 838 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 839 unsigned numOperands) { 840 if (op->getNumResults() < numOperands) 841 return op->emitOpError() 
842 << "expected " << numOperands << " or more results"; 843 return success(); 844 } 845 846 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 847 if (failed(verifyAtLeastNOperands(op, 1))) 848 return failure(); 849 850 if (failed(verifyCompatibleShapes(op->getOperandTypes()))) 851 return op->emitOpError() << "requires the same shape for all operands"; 852 853 return success(); 854 } 855 856 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 857 if (failed(verifyAtLeastNOperands(op, 1)) || 858 failed(verifyAtLeastNResults(op, 1))) 859 return failure(); 860 861 SmallVector<Type, 8> types(op->getOperandTypes()); 862 types.append(llvm::to_vector<4>(op->getResultTypes())); 863 864 if (failed(verifyCompatibleShapes(types))) 865 return op->emitOpError() 866 << "requires the same shape for all operands and results"; 867 868 return success(); 869 } 870 871 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 872 if (failed(verifyAtLeastNOperands(op, 1))) 873 return failure(); 874 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 875 876 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 877 if (getElementTypeOrSelf(operand) != elementType) 878 return op->emitOpError("requires the same element type for all operands"); 879 } 880 881 return success(); 882 } 883 884 LogicalResult 885 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 886 if (failed(verifyAtLeastNOperands(op, 1)) || 887 failed(verifyAtLeastNResults(op, 1))) 888 return failure(); 889 890 auto elementType = getElementTypeOrSelf(op->getResult(0)); 891 892 // Verify result element type matches first result's element type. 
893 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 894 if (getElementTypeOrSelf(result) != elementType) 895 return op->emitOpError( 896 "requires the same element type for all operands and results"); 897 } 898 899 // Verify operand's element type matches first result's element type. 900 for (auto operand : op->getOperands()) { 901 if (getElementTypeOrSelf(operand) != elementType) 902 return op->emitOpError( 903 "requires the same element type for all operands and results"); 904 } 905 906 return success(); 907 } 908 909 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 910 if (failed(verifyAtLeastNOperands(op, 1)) || 911 failed(verifyAtLeastNResults(op, 1))) 912 return failure(); 913 914 auto type = op->getResult(0).getType(); 915 auto elementType = getElementTypeOrSelf(type); 916 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 917 if (getElementTypeOrSelf(resultType) != elementType || 918 failed(verifyCompatibleShape(resultType, type))) 919 return op->emitOpError() 920 << "requires the same type for all operands and results"; 921 } 922 for (auto opType : op->getOperandTypes()) { 923 if (getElementTypeOrSelf(opType) != elementType || 924 failed(verifyCompatibleShape(opType, type))) 925 return op->emitOpError() 926 << "requires the same type for all operands and results"; 927 } 928 return success(); 929 } 930 931 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 932 Block *block = op->getBlock(); 933 // Verify that the operation is at the end of the respective parent block. 934 if (!block || &block->back() != op) 935 return op->emitOpError("must be the last operation in the parent block"); 936 return success(); 937 } 938 939 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 940 auto *parent = op->getParentRegion(); 941 942 // Verify that the operands lines up with the BB arguments in the successor. 
943 for (Block *succ : op->getSuccessors()) 944 if (succ->getParent() != parent) 945 return op->emitError("reference to block defined in another region"); 946 return success(); 947 } 948 949 LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { 950 if (op->getNumSuccessors() != 0) { 951 return op->emitOpError("requires 0 successors but found ") 952 << op->getNumSuccessors(); 953 } 954 return success(); 955 } 956 957 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 958 if (op->getNumSuccessors() != 1) { 959 return op->emitOpError("requires 1 successor but found ") 960 << op->getNumSuccessors(); 961 } 962 return verifyTerminatorSuccessors(op); 963 } 964 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 965 unsigned numSuccessors) { 966 if (op->getNumSuccessors() != numSuccessors) { 967 return op->emitOpError("requires ") 968 << numSuccessors << " successors but found " 969 << op->getNumSuccessors(); 970 } 971 return verifyTerminatorSuccessors(op); 972 } 973 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 974 unsigned numSuccessors) { 975 if (op->getNumSuccessors() < numSuccessors) { 976 return op->emitOpError("requires at least ") 977 << numSuccessors << " successors but found " 978 << op->getNumSuccessors(); 979 } 980 return verifyTerminatorSuccessors(op); 981 } 982 983 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 984 for (auto resultType : op->getResultTypes()) { 985 auto elementType = getTensorOrVectorElementType(resultType); 986 bool isBoolType = elementType.isInteger(1); 987 if (!isBoolType) 988 return op->emitOpError() << "requires a bool result type"; 989 } 990 991 return success(); 992 } 993 994 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 995 for (auto resultType : op->getResultTypes()) 996 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 997 return op->emitOpError() << "requires a floating point type"; 998 999 return success(); 
1000 } 1001 1002 LogicalResult 1003 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 1004 for (auto resultType : op->getResultTypes()) 1005 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 1006 return op->emitOpError() << "requires an integer or index type"; 1007 return success(); 1008 } 1009 1010 LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, 1011 StringRef attrName, 1012 StringRef valueGroupName, 1013 size_t expectedCount) { 1014 auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); 1015 if (!sizeAttr) 1016 return op->emitOpError("requires 1D i32 elements attribute '") 1017 << attrName << "'"; 1018 1019 auto sizeAttrType = sizeAttr.getType(); 1020 if (sizeAttrType.getRank() != 1 || 1021 !sizeAttrType.getElementType().isInteger(32)) 1022 return op->emitOpError("requires 1D i32 elements attribute '") 1023 << attrName << "'"; 1024 1025 if (llvm::any_of(sizeAttr.getValues<APInt>(), [](const APInt &element) { 1026 return !element.isNonNegative(); 1027 })) 1028 return op->emitOpError("'") 1029 << attrName << "' attribute cannot have negative elements"; 1030 1031 size_t totalCount = std::accumulate( 1032 sizeAttr.begin(), sizeAttr.end(), 0, 1033 [](unsigned all, APInt one) { return all + one.getZExtValue(); }); 1034 1035 if (totalCount != expectedCount) 1036 return op->emitOpError() 1037 << valueGroupName << " count (" << expectedCount 1038 << ") does not match with the total size (" << totalCount 1039 << ") specified in attribute '" << attrName << "'"; 1040 return success(); 1041 } 1042 1043 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1044 StringRef attrName) { 1045 return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); 1046 } 1047 1048 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1049 StringRef attrName) { 1050 return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); 1051 } 1052 1053 LogicalResult 
OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1054 for (Region ®ion : op->getRegions()) { 1055 if (region.empty()) 1056 continue; 1057 1058 if (region.getNumArguments() != 0) { 1059 if (op->getNumRegions() > 1) 1060 return op->emitOpError("region #") 1061 << region.getRegionNumber() << " should have no arguments"; 1062 else 1063 return op->emitOpError("region should have no arguments"); 1064 } 1065 } 1066 return success(); 1067 } 1068 1069 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1070 auto isMappableType = [](Type type) { 1071 return type.isa<VectorType, TensorType>(); 1072 }; 1073 auto resultMappableTypes = llvm::to_vector<1>( 1074 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1075 auto operandMappableTypes = llvm::to_vector<2>( 1076 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1077 1078 // If the op only has scalar operand/result types, then we have nothing to 1079 // check. 1080 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1081 return success(); 1082 1083 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1084 return op->emitOpError("if a result is non-scalar, then at least one " 1085 "operand must be non-scalar"); 1086 1087 assert(!operandMappableTypes.empty()); 1088 1089 if (resultMappableTypes.empty()) 1090 return op->emitOpError("if an operand is non-scalar, then there must be at " 1091 "least one non-scalar result"); 1092 1093 if (resultMappableTypes.size() != op->getNumResults()) 1094 return op->emitOpError( 1095 "if an operand is non-scalar, then all results must be non-scalar"); 1096 1097 SmallVector<Type, 4> types = llvm::to_vector<2>( 1098 llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); 1099 TypeID expectedBaseTy = types.front().getTypeID(); 1100 if (!llvm::all_of(types, 1101 [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || 1102 failed(verifyCompatibleShapes(types))) { 1103 return op->emitOpError() << "all non-scalar 
operands/results must have the " 1104 "same shape and base type"; 1105 } 1106 1107 return success(); 1108 } 1109 1110 /// Check for any values used by operations regions attached to the 1111 /// specified "IsIsolatedFromAbove" operation defined outside of it. 1112 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { 1113 assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1114 "Intended to check IsolatedFromAbove ops"); 1115 1116 // List of regions to analyze. Each region is processed independently, with 1117 // respect to the common `limit` region, so we can look at them in any order. 1118 // Therefore, use a simple vector and push/pop back the current region. 1119 SmallVector<Region *, 8> pendingRegions; 1120 for (auto ®ion : isolatedOp->getRegions()) { 1121 pendingRegions.push_back(®ion); 1122 1123 // Traverse all operations in the region. 1124 while (!pendingRegions.empty()) { 1125 for (Operation &op : pendingRegions.pop_back_val()->getOps()) { 1126 for (Value operand : op.getOperands()) { 1127 // operand should be non-null here if the IR is well-formed. But 1128 // we don't assert here as this function is called from the verifier 1129 // and so could be called on invalid IR. 1130 if (!operand) 1131 return op.emitOpError("operation's operand is null"); 1132 1133 // Check that any value that is used by an operation is defined in the 1134 // same region as either an operation result. 1135 auto *operandRegion = operand.getParentRegion(); 1136 if (!region.isAncestor(operandRegion)) { 1137 return op.emitOpError("using value defined outside the region") 1138 .attachNote(isolatedOp->getLoc()) 1139 << "required by region isolation constraints"; 1140 } 1141 } 1142 1143 // Schedule any regions in the operation for further checking. Don't 1144 // recurse into other IsolatedFromAbove ops, because they will check 1145 // themselves. 
1146 if (op.getNumRegions() && 1147 !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1148 for (Region &subRegion : op.getRegions()) 1149 pendingRegions.push_back(&subRegion); 1150 } 1151 } 1152 } 1153 } 1154 1155 return success(); 1156 } 1157 1158 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1159 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1160 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1161 } 1162 1163 //===----------------------------------------------------------------------===// 1164 // BinaryOp implementation 1165 //===----------------------------------------------------------------------===// 1166 1167 // These functions are out-of-line implementations of the methods in BinaryOp, 1168 // which avoids them being template instantiated/duplicated. 1169 1170 void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, 1171 Value rhs) { 1172 assert(lhs.getType() == rhs.getType()); 1173 result.addOperands({lhs, rhs}); 1174 result.types.push_back(lhs.getType()); 1175 } 1176 1177 ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser, 1178 OperationState &result) { 1179 SmallVector<OpAsmParser::OperandType, 2> ops; 1180 Type type; 1181 // If the operand list is in-between parentheses, then we have a generic form. 1182 // (see the fallback in `printOneResultOp`). 
1183 llvm::SMLoc loc = parser.getCurrentLocation(); 1184 if (!parser.parseOptionalLParen()) { 1185 if (parser.parseOperandList(ops) || parser.parseRParen() || 1186 parser.parseOptionalAttrDict(result.attributes) || 1187 parser.parseColon() || parser.parseType(type)) 1188 return failure(); 1189 auto fnType = type.dyn_cast<FunctionType>(); 1190 if (!fnType) { 1191 parser.emitError(loc, "expected function type"); 1192 return failure(); 1193 } 1194 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) 1195 return failure(); 1196 result.addTypes(fnType.getResults()); 1197 return success(); 1198 } 1199 return failure(parser.parseOperandList(ops) || 1200 parser.parseOptionalAttrDict(result.attributes) || 1201 parser.parseColonType(type) || 1202 parser.resolveOperands(ops, type, result.operands) || 1203 parser.addTypeToList(type, result.types)); 1204 } 1205 1206 void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { 1207 assert(op->getNumResults() == 1 && "op should have one result"); 1208 1209 // If not all the operand and result types are the same, just use the 1210 // generic assembly form to avoid omitting information in printing. 1211 auto resultType = op->getResult(0).getType(); 1212 if (llvm::any_of(op->getOperandTypes(), 1213 [&](Type type) { return type != resultType; })) { 1214 p.printGenericOp(op, /*printOpName=*/false); 1215 return; 1216 } 1217 1218 p << ' '; 1219 p.printOperands(op->getOperands()); 1220 p.printOptionalAttrDict(op->getAttrs()); 1221 // Now we can output only one type for all operands and the result. 1222 p << " : " << resultType; 1223 } 1224 1225 //===----------------------------------------------------------------------===// 1226 // CastOp implementation 1227 //===----------------------------------------------------------------------===// 1228 1229 /// Attempt to fold the given cast operation. 
1230 LogicalResult 1231 impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, 1232 SmallVectorImpl<OpFoldResult> &foldResults) { 1233 OperandRange operands = op->getOperands(); 1234 if (operands.empty()) 1235 return failure(); 1236 ResultRange results = op->getResults(); 1237 1238 // Check for the case where the input and output types match 1-1. 1239 if (operands.getTypes() == results.getTypes()) { 1240 foldResults.append(operands.begin(), operands.end()); 1241 return success(); 1242 } 1243 1244 return failure(); 1245 } 1246 1247 /// Attempt to verify the given cast operation. 1248 LogicalResult impl::verifyCastInterfaceOp( 1249 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { 1250 auto resultTypes = op->getResultTypes(); 1251 if (llvm::empty(resultTypes)) 1252 return op->emitOpError() 1253 << "expected at least one result for cast operation"; 1254 1255 auto operandTypes = op->getOperandTypes(); 1256 if (!areCastCompatible(operandTypes, resultTypes)) { 1257 InFlightDiagnostic diag = op->emitOpError("operand type"); 1258 if (llvm::empty(operandTypes)) 1259 diag << "s []"; 1260 else if (llvm::size(operandTypes) == 1) 1261 diag << " " << *operandTypes.begin(); 1262 else 1263 diag << "s " << operandTypes; 1264 return diag << " and result type" << (resultTypes.size() == 1 ? 
" " : "s ") 1265 << resultTypes << " are cast incompatible"; 1266 } 1267 1268 return success(); 1269 } 1270 1271 void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source, 1272 Type destType) { 1273 result.addOperands(source); 1274 result.addTypes(destType); 1275 } 1276 1277 ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) { 1278 OpAsmParser::OperandType srcInfo; 1279 Type srcType, dstType; 1280 return failure(parser.parseOperand(srcInfo) || 1281 parser.parseOptionalAttrDict(result.attributes) || 1282 parser.parseColonType(srcType) || 1283 parser.resolveOperand(srcInfo, srcType, result.operands) || 1284 parser.parseKeywordType("to", dstType) || 1285 parser.addTypeToList(dstType, result.types)); 1286 } 1287 1288 void impl::printCastOp(Operation *op, OpAsmPrinter &p) { 1289 p << ' ' << op->getOperand(0); 1290 p.printOptionalAttrDict(op->getAttrs()); 1291 p << " : " << op->getOperand(0).getType() << " to " 1292 << op->getResult(0).getType(); 1293 } 1294 1295 Value impl::foldCastOp(Operation *op) { 1296 // Identity cast 1297 if (op->getOperand(0).getType() == op->getResult(0).getType()) 1298 return op->getOperand(0); 1299 return nullptr; 1300 } 1301 1302 LogicalResult 1303 impl::verifyCastOp(Operation *op, 1304 function_ref<bool(Type, Type)> areCastCompatible) { 1305 auto opType = op->getOperand(0).getType(); 1306 auto resType = op->getResult(0).getType(); 1307 if (!areCastCompatible(opType, resType)) 1308 return op->emitError("operand type ") 1309 << opType << " and result type " << resType 1310 << " are cast incompatible"; 1311 1312 return success(); 1313 } 1314 1315 //===----------------------------------------------------------------------===// 1316 // Misc. utils 1317 //===----------------------------------------------------------------------===// 1318 1319 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1320 /// region's only block if it does not have a terminator already. 
If the region 1321 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1322 /// terminator operation to insert. 1323 void impl::ensureRegionTerminator( 1324 Region ®ion, OpBuilder &builder, Location loc, 1325 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1326 OpBuilder::InsertionGuard guard(builder); 1327 if (region.empty()) 1328 builder.createBlock(®ion); 1329 1330 Block &block = region.back(); 1331 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1332 return; 1333 1334 builder.setInsertionPointToEnd(&block); 1335 builder.insert(buildTerminatorOp(builder, loc)); 1336 } 1337 1338 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1339 /// function. 1340 void impl::ensureRegionTerminator( 1341 Region ®ion, Builder &builder, Location loc, 1342 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1343 OpBuilder opBuilder(builder.getContext()); 1344 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1345 } 1346