1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Interfaces/FoldInterfaces.h" 17 #include "llvm/ADT/StringExtras.h" 18 #include <numeric> 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // OperationName 24 //===----------------------------------------------------------------------===// 25 26 /// Form the OperationName for an op with the specified string. This either is 27 /// a reference to an AbstractOperation if one is known, or a uniqued Identifier 28 /// if not. 29 OperationName::OperationName(StringRef name, MLIRContext *context) { 30 if (auto *op = AbstractOperation::lookup(name, context)) 31 representation = op; 32 else 33 representation = StringAttr::get(name, context); 34 } 35 36 /// Return the name of the dialect this operation is registered to. 37 StringRef OperationName::getDialectNamespace() const { 38 if (Dialect *dialect = getDialect()) 39 return dialect->getNamespace(); 40 return getStringRef().split('.').first; 41 } 42 43 /// Return the operation name with dialect name stripped, if it has one. 44 StringRef OperationName::stripDialect() const { 45 return getStringRef().split('.').second; 46 } 47 48 /// Return the name of this operation. This always succeeds. 
49 StringRef OperationName::getStringRef() const { 50 return getIdentifier().strref(); 51 } 52 53 /// Return the name of this operation as an identifier. This always succeeds. 54 StringAttr OperationName::getIdentifier() const { 55 if (auto *op = representation.dyn_cast<const AbstractOperation *>()) 56 return op->name; 57 return representation.get<StringAttr>(); 58 } 59 60 OperationName OperationName::getFromOpaquePointer(const void *pointer) { 61 return OperationName( 62 RepresentationUnion::getFromOpaqueValue(const_cast<void *>(pointer))); 63 } 64 65 //===----------------------------------------------------------------------===// 66 // Operation 67 //===----------------------------------------------------------------------===// 68 69 /// Create a new Operation with the specific fields. 70 Operation *Operation::create(Location location, OperationName name, 71 TypeRange resultTypes, ValueRange operands, 72 ArrayRef<NamedAttribute> attributes, 73 BlockRange successors, unsigned numRegions) { 74 return create(location, name, resultTypes, operands, 75 DictionaryAttr::get(location.getContext(), attributes), 76 successors, numRegions); 77 } 78 79 /// Create a new Operation from operation state. 80 Operation *Operation::create(const OperationState &state) { 81 return create(state.location, state.name, state.types, state.operands, 82 state.attributes.getDictionary(state.getContext()), 83 state.successors, state.regions); 84 } 85 86 /// Create a new Operation with the specific fields. 
Operation *Operation::create(Location location, OperationName name,
                             TypeRange resultTypes, ValueRange operands,
                             DictionaryAttr attributes, BlockRange successors,
                             RegionRange regions) {
  // Create the operation with `regions.size()` empty regions first...
  unsigned numRegions = regions.size();
  Operation *op = create(location, name, resultTypes, operands, attributes,
                         successors, numRegions);
  // ...then move the body of each non-null provided region into the matching
  // region of the new operation. Null entries leave that region empty.
  for (unsigned i = 0; i < numRegions; ++i)
    if (regions[i])
      op->getRegion(i).takeBody(*regions[i]);
  return op;
}

/// Overload of create that takes an existing DictionaryAttr to avoid
/// unnecessarily uniquing a list of attributes.
///
/// All storage comes from a single malloc: a prefix sized from the result
/// counts (presumably holding the OpResults), followed by the Operation itself
/// and its trailing objects (OperandStorage if needed, BlockOperands, Regions,
/// OpOperands). The returned pointer points past the prefix; destroy() steps
/// back over it before calling free.
Operation *Operation::create(Location location, OperationName name,
                             TypeRange resultTypes, ValueRange operands,
                             DictionaryAttr attributes, BlockRange successors,
                             unsigned numRegions) {
  assert(llvm::all_of(resultTypes, [](Type t) { return t; }) &&
         "unexpected null result type");

  // We only need to allocate additional memory for a subset of results.
  unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size());
  unsigned numInlineResults = OpResult::getNumInline(resultTypes.size());
  unsigned numSuccessors = successors.size();
  unsigned numOperands = operands.size();
  unsigned numResults = resultTypes.size();

  // If the operation is known to have no operands, don't allocate an operand
  // storage. Only registered operations with the ZeroOperands trait qualify;
  // an unregistered op may gain operands later, so it always gets storage.
  bool needsOperandStorage = true;
  if (operands.empty()) {
    if (const AbstractOperation *abstractOp = name.getAbstractOperation())
      needsOperandStorage = !abstractOp->hasTrait<OpTrait::ZeroOperands>();
  }

  // Compute the byte size for the operation and the operand storage. This takes
  // into account the size of the operation, its trailing objects, and its
  // prefixed objects.
  size_t byteSize =
      totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>(
          needsOperandStorage ? 1 : 0, numSuccessors, numRegions, numOperands);
  // Round the prefix up to Operation's alignment so the Operation placed right
  // after it is correctly aligned.
  size_t prefixByteSize = llvm::alignTo(
      Operation::prefixAllocSize(numTrailingResults, numInlineResults),
      alignof(Operation));
  // NOTE(review): the malloc result is not checked for null.
  char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize));
  void *rawMem = mallocMem + prefixByteSize;

  // Create the new Operation.
  Operation *op =
      ::new (rawMem) Operation(location, name, numResults, numSuccessors,
                               numRegions, attributes, needsOperandStorage);

  assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) &&
         "unexpected successors in a non-terminator operation");

  // Initialize the results. The first `numInlineResults` types become inline
  // results; the remaining `numTrailingResults` become out-of-line results.
  // Both loops advance the same `resultTypeIt`.
  auto resultTypeIt = resultTypes.begin();
  for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt)
    new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i);
  for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) {
    new (op->getOutOfLineOpResult(i))
        detail::OutOfLineOpResult(*resultTypeIt, i);
  }

  // Initialize the regions.
  for (unsigned i = 0; i != numRegions; ++i)
    new (&op->getRegion(i)) Region(op);

  // Initialize the operands.
  if (needsOperandStorage) {
    new (&op->getOperandStorage()) detail::OperandStorage(
        op, op->getTrailingObjects<OpOperand>(), operands);
  }

  // Initialize the successors.
  auto blockOperands = op->getBlockOperands();
  for (unsigned i = 0; i != numSuccessors; ++i)
    new (&blockOperands[i]) BlockOperand(op, successors[i]);

  return op;
}

/// Construct the Operation header in pre-allocated memory. This is used by
/// create() through placement new. Results, regions, operands and successors
/// are initialized separately by create().
Operation::Operation(Location location, OperationName name, unsigned numResults,
                     unsigned numSuccessors, unsigned numRegions,
                     DictionaryAttr attributes, bool hasOperandStorage)
    : location(location), numResults(numResults), numSuccs(numSuccessors),
      numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name),
      attrs(attributes) {
  assert(attributes && "unexpected null attribute dictionary");
#ifndef NDEBUG
  // Debug builds reject ops from unregistered dialects unless the context
  // explicitly allows them.
  if (!getDialect() && !getContext()->allowsUnregisteredDialects())
    llvm::report_fatal_error(
        name.getStringRef() +
        " created with unregistered dialect. If this is intended, please call "
        "allowUnregisteredDialects() on the MLIRContext, or use "
        "-allow-unregistered-dialect with the MLIR tool used.");
#endif
}

// Operations are deleted through the destroy() member because they are
// allocated via malloc.
Operation::~Operation() {
  assert(block == nullptr && "operation destroyed but still in a block");
#ifndef NDEBUG
  if (!use_empty()) {
    // The inner scope ends `diag`'s lifetime (so it is emitted, presumably on
    // destruction) before report_fatal_error aborts.
    {
      InFlightDiagnostic diag =
          emitOpError("operation destroyed but still has uses");
      for (Operation *user : getUsers())
        diag.attachNote(user->getLoc()) << "- use: " << *user << "\n";
    }
    llvm::report_fatal_error("operation destroyed but still has uses");
  }
#endif
  // Explicitly run the destructors for the operands. They were constructed
  // with placement new in create(), so no delete is involved.
  if (hasOperandStorage)
    getOperandStorage().~OperandStorage();

  // Explicitly run the destructors for the successors.
  for (auto &successor : getBlockOperands())
    successor.~BlockOperand();

  // Explicitly destroy the regions.
  for (auto &region : getRegions())
    region.~Region();
}

/// Destroy this operation or one of its subclasses.
217 void Operation::destroy() { 218 // Operations may have additional prefixed allocation, which needs to be 219 // accounted for here when computing the address to free. 220 char *rawMem = reinterpret_cast<char *>(this) - 221 llvm::alignTo(prefixAllocSize(), alignof(Operation)); 222 this->~Operation(); 223 free(rawMem); 224 } 225 226 /// Return true if this operation is a proper ancestor of the `other` 227 /// operation. 228 bool Operation::isProperAncestor(Operation *other) { 229 while ((other = other->getParentOp())) 230 if (this == other) 231 return true; 232 return false; 233 } 234 235 /// Replace any uses of 'from' with 'to' within this operation. 236 void Operation::replaceUsesOfWith(Value from, Value to) { 237 if (from == to) 238 return; 239 for (auto &operand : getOpOperands()) 240 if (operand.get() == from) 241 operand.set(to); 242 } 243 244 /// Replace the current operands of this operation with the ones provided in 245 /// 'operands'. 246 void Operation::setOperands(ValueRange operands) { 247 if (LLVM_LIKELY(hasOperandStorage)) 248 return getOperandStorage().setOperands(this, operands); 249 assert(operands.empty() && "setting operands without an operand storage"); 250 } 251 252 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 253 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 254 /// than the range pointed to by 'start'+'length'. 255 void Operation::setOperands(unsigned start, unsigned length, 256 ValueRange operands) { 257 assert((start + length) <= getNumOperands() && 258 "invalid operand range specified"); 259 if (LLVM_LIKELY(hasOperandStorage)) 260 return getOperandStorage().setOperands(this, start, length, operands); 261 assert(operands.empty() && "setting operands without an operand storage"); 262 } 263 264 /// Insert the given operands into the operand list at the given 'index'. 
265 void Operation::insertOperands(unsigned index, ValueRange operands) { 266 if (LLVM_LIKELY(hasOperandStorage)) 267 return setOperands(index, /*length=*/0, operands); 268 assert(operands.empty() && "inserting operands without an operand storage"); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // Diagnostics 273 //===----------------------------------------------------------------------===// 274 275 /// Emit an error about fatal conditions with this operation, reporting up to 276 /// any diagnostic handlers that may be listening. 277 InFlightDiagnostic Operation::emitError(const Twine &message) { 278 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 279 if (getContext()->shouldPrintOpOnDiagnostic()) { 280 diag.attachNote(getLoc()) 281 .append("see current operation: ") 282 .appendOp(*this, OpPrintingFlags().printGenericOpForm()); 283 } 284 return diag; 285 } 286 287 /// Emit a warning about this operation, reporting up to any diagnostic 288 /// handlers that may be listening. 289 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 290 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 291 if (getContext()->shouldPrintOpOnDiagnostic()) 292 diag.attachNote(getLoc()) << "see current operation: " << *this; 293 return diag; 294 } 295 296 /// Emit a remark about this operation, reporting up to any diagnostic 297 /// handlers that may be listening. 
InFlightDiagnostic Operation::emitRemark(const Twine &message) {
  InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message);
  if (getContext()->shouldPrintOpOnDiagnostic())
    diag.attachNote(getLoc()) << "see current operation: " << *this;
  return diag;
}

//===----------------------------------------------------------------------===//
// Operation Ordering
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the static constexpr ordering constants declared
// in the class (needed for ODR-use before C++17's implicit-inline rule).
constexpr unsigned Operation::kInvalidOrderIdx;
constexpr unsigned Operation::kOrderStride;

/// Given an operation 'other' that is within the same parent block, return
/// whether the current operation is before 'other' in the operation list
/// of the parent block.
/// Note: This function has an average complexity of O(1), but worst case may
/// take O(N) where N is the number of operations within the parent block.
bool Operation::isBeforeInBlock(Operation *other) {
  assert(block && "Operations without parent blocks have no order.");
  assert(other && other->block == block &&
         "Expected other operation to have the same parent block.");
  // If the order of the block is already invalid, directly recompute the
  // parent.
  if (!block->isOpOrderValid()) {
    block->recomputeOpOrder();
  } else {
    // Update the order of either operation if necessary.
    updateOrderIfNecessary();
    other->updateOrderIfNecessary();
  }

  // Both indices are now valid; relative position is their numeric order.
  return orderIndex < other->orderIndex;
}

/// Update the order index of this operation if necessary, potentially
/// recomputing the order of the parent block.
///
/// Strategy: derive a fresh index from the neighbors when they have valid
/// indices and leave room, otherwise fall back to renumbering the whole block
/// via recomputeOpOrder().
void Operation::updateOrderIfNecessary() {
  assert(block && "expected valid parent");

  // If the order is valid for this operation there is nothing to do.
  if (hasValidOrder())
    return;
  Operation *blockFront = &block->front();
  Operation *blockBack = &block->back();

  // This method is expected to only be invoked on blocks with more than one
  // operation.
  assert(blockFront != blockBack && "expected more than one operation");

  // If the operation is at the end of the block.
  if (this == blockBack) {
    Operation *prevNode = getPrevNode();
    if (!prevNode->hasValidOrder())
      return block->recomputeOpOrder();

    // Add the stride to the previous operation.
    // NOTE(review): this unsigned addition is not checked for wraparound.
    orderIndex = prevNode->orderIndex + kOrderStride;
    return;
  }

  // If this is the first operation try to use the next operation to compute the
  // ordering.
  if (this == blockFront) {
    Operation *nextNode = getNextNode();
    if (!nextNode->hasValidOrder())
      return block->recomputeOpOrder();
    // There is no order to give this operation.
    if (nextNode->orderIndex == 0)
      return block->recomputeOpOrder();

    // If we can't use the stride, take half of the next operation's index.
    // This is safe because the next index is non-zero (checked above), so
    // there is at least one smaller valid index to assign.
    if (nextNode->orderIndex <= kOrderStride)
      orderIndex = (nextNode->orderIndex / 2);
    else
      orderIndex = kOrderStride;
    return;
  }

  // Otherwise, this operation is between two others. Place this operation in
  // the middle of the previous and next if possible.
  Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
  if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
    return block->recomputeOpOrder();
  unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;

  // Check to see if there is a valid order between the two; adjacent indices
  // leave no gap, so the block must be renumbered.
  if (prevOrder + 1 == nextOrder)
    return block->recomputeOpOrder();
  // Midpoint written as prev + (next - prev) / 2 to avoid overflowing
  // prev + next.
  orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}

//===----------------------------------------------------------------------===//
// ilist_traits for Operation
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the ilist node accessors for Operation; each one
// forwards to NodeAccess with this list's node options.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(N);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer N)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(N);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(N);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *N)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(N);
}

/// Operations are malloc'd with a prefix, so list deletion must go through
/// Operation::destroy() rather than `delete`.
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
  op->destroy();
}

/// Recover the Block that owns this operation list. The list's byte offset
/// inside Block is computed by applying Block's sublist member pointer to a
/// null Block pointer (an offsetof-style idiom; NOTE(review): formally
/// undefined behavior). That offset is then subtracted from the list's own
/// address.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset);
}

/// This is a trait method invoked when an operation is added to a block. We
/// keep the block pointer up to date.
434 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 435 assert(!op->getBlock() && "already in an operation block!"); 436 op->block = getContainingBlock(); 437 438 // Invalidate the order on the operation. 439 op->orderIndex = Operation::kInvalidOrderIdx; 440 } 441 442 /// This is a trait method invoked when an operation is removed from a block. 443 /// We keep the block pointer up to date. 444 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 445 assert(op->block && "not already in an operation block!"); 446 op->block = nullptr; 447 } 448 449 /// This is a trait method invoked when an operation is moved from one block 450 /// to another. We keep the block pointer up to date. 451 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 452 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 453 Block *curParent = getContainingBlock(); 454 455 // Invalidate the ordering of the parent block. 456 curParent->invalidateOpOrder(); 457 458 // If we are transferring operations within the same block, the block 459 // pointer doesn't need to be updated. 460 if (curParent == otherList.getContainingBlock()) 461 return; 462 463 // Update the 'block' member of each operation. 464 for (; first != last; ++first) 465 first->block = curParent; 466 } 467 468 /// Remove this operation (and its descendants) from its Block and delete 469 /// all of them. 470 void Operation::erase() { 471 if (auto *parent = getBlock()) 472 parent->getOperations().erase(this); 473 else 474 destroy(); 475 } 476 477 /// Remove the operation from its parent block, but don't delete it. 478 void Operation::remove() { 479 if (Block *parent = getBlock()) 480 parent->getOperations().remove(this); 481 } 482 483 /// Unlink this operation from its current block and insert it right before 484 /// `existingOp` which may be in the same or another block in the same 485 /// function. 
486 void Operation::moveBefore(Operation *existingOp) { 487 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 488 } 489 490 /// Unlink this operation from its current basic block and insert it right 491 /// before `iterator` in the specified basic block. 492 void Operation::moveBefore(Block *block, 493 llvm::iplist<Operation>::iterator iterator) { 494 block->getOperations().splice(iterator, getBlock()->getOperations(), 495 getIterator()); 496 } 497 498 /// Unlink this operation from its current block and insert it right after 499 /// `existingOp` which may be in the same or another block in the same function. 500 void Operation::moveAfter(Operation *existingOp) { 501 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 502 } 503 504 /// Unlink this operation from its current block and insert it right after 505 /// `iterator` in the specified block. 506 void Operation::moveAfter(Block *block, 507 llvm::iplist<Operation>::iterator iterator) { 508 assert(iterator != block->end() && "cannot move after end of block"); 509 moveBefore(block, std::next(iterator)); 510 } 511 512 /// This drops all operand uses from this operation, which is an essential 513 /// step in breaking cyclic dependences between references when they are to 514 /// be deleted. 515 void Operation::dropAllReferences() { 516 for (auto &op : getOpOperands()) 517 op.drop(); 518 519 for (auto ®ion : getRegions()) 520 region.dropAllReferences(); 521 522 for (auto &dest : getBlockOperands()) 523 dest.drop(); 524 } 525 526 /// This drops all uses of any values defined by this operation or its nested 527 /// regions, wherever they are located. 
528 void Operation::dropAllDefinedValueUses() { 529 dropAllUses(); 530 531 for (auto ®ion : getRegions()) 532 for (auto &block : region) 533 block.dropAllDefinedValueUses(); 534 } 535 536 void Operation::setSuccessor(Block *block, unsigned index) { 537 assert(index < getNumSuccessors()); 538 getBlockOperands()[index].set(block); 539 } 540 541 /// Attempt to fold this operation using the Op's registered foldHook. 542 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 543 SmallVectorImpl<OpFoldResult> &results) { 544 // If we have a registered operation definition matching this one, use it to 545 // try to constant fold the operation. 546 auto *abstractOp = getAbstractOperation(); 547 if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results))) 548 return success(); 549 550 // Otherwise, fall back on the dialect hook to handle it. 551 Dialect *dialect = getDialect(); 552 if (!dialect) 553 return failure(); 554 555 auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>(); 556 if (!interface) 557 return failure(); 558 559 return interface->fold(this, operands, results); 560 } 561 562 /// Emit an error with the op name prefixed, like "'dim' op " which is 563 /// convenient for verifiers. 564 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 565 return emitError() << "'" << getName() << "' op " << message; 566 } 567 568 //===----------------------------------------------------------------------===// 569 // Operation Cloning 570 //===----------------------------------------------------------------------===// 571 572 /// Create a deep copy of this operation but keep the operation regions empty. 573 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 574 /// to contain the results. 575 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 576 SmallVector<Value, 8> operands; 577 SmallVector<Block *, 2> successors; 578 579 // Remap the operands. 
580 operands.reserve(getNumOperands()); 581 for (auto opValue : getOperands()) 582 operands.push_back(mapper.lookupOrDefault(opValue)); 583 584 // Remap the successors. 585 successors.reserve(getNumSuccessors()); 586 for (Block *successor : getSuccessors()) 587 successors.push_back(mapper.lookupOrDefault(successor)); 588 589 // Create the new operation. 590 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 591 successors, getNumRegions()); 592 593 // Remember the mapping of any results. 594 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 595 mapper.map(getResult(i), newOp->getResult(i)); 596 597 return newOp; 598 } 599 600 Operation *Operation::cloneWithoutRegions() { 601 BlockAndValueMapping mapper; 602 return cloneWithoutRegions(mapper); 603 } 604 605 /// Create a deep copy of this operation, remapping any operands that use 606 /// values outside of the operation using the map that is provided (leaving 607 /// them alone if no entry is present). Replaces references to cloned 608 /// sub-operations to the corresponding operation that is copied, and adds 609 /// those mappings to the map. 610 Operation *Operation::clone(BlockAndValueMapping &mapper) { 611 auto *newOp = cloneWithoutRegions(mapper); 612 613 // Clone the regions. 614 for (unsigned i = 0; i != numRegions; ++i) 615 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 616 617 return newOp; 618 } 619 620 Operation *Operation::clone() { 621 BlockAndValueMapping mapper; 622 return clone(mapper); 623 } 624 625 //===----------------------------------------------------------------------===// 626 // OpState trait class. 627 //===----------------------------------------------------------------------===// 628 629 // The fallback for the parser is to reject the custom assembly form. 
630 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 631 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 632 } 633 634 // The fallback for the printer is to print in the generic assembly form. 635 void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); } 636 // The fallback for the printer is to print in the generic assembly form. 637 void OpState::printOpName(Operation *op, OpAsmPrinter &p, 638 StringRef defaultDialect) { 639 StringRef name = op->getName().getStringRef(); 640 if (name.startswith((defaultDialect + ".").str())) 641 name = name.drop_front(defaultDialect.size() + 1); 642 // TODO: remove this special case (and update test/IR/parser.mlir) 643 else if ((defaultDialect.empty() || defaultDialect == "builtin") && 644 name.startswith("std.")) 645 name = name.drop_front(4); 646 p.getStream() << name; 647 } 648 649 /// Emit an error about fatal conditions with this operation, reporting up to 650 /// any diagnostic handlers that may be listening. 651 InFlightDiagnostic OpState::emitError(const Twine &message) { 652 return getOperation()->emitError(message); 653 } 654 655 /// Emit an error with the op name prefixed, like "'dim' op " which is 656 /// convenient for verifiers. 657 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 658 return getOperation()->emitOpError(message); 659 } 660 661 /// Emit a warning about this operation, reporting up to any diagnostic 662 /// handlers that may be listening. 663 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 664 return getOperation()->emitWarning(message); 665 } 666 667 /// Emit a remark about this operation, reporting up to any diagnostic 668 /// handlers that may be listening. 
669 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 670 return getOperation()->emitRemark(message); 671 } 672 673 //===----------------------------------------------------------------------===// 674 // Op Trait implementations 675 //===----------------------------------------------------------------------===// 676 677 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 678 auto *argumentOp = op->getOperand(0).getDefiningOp(); 679 if (argumentOp && op->getName() == argumentOp->getName()) { 680 // Replace the outer operation output with the inner operation. 681 return op->getOperand(0); 682 } 683 684 return {}; 685 } 686 687 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 688 auto *argumentOp = op->getOperand(0).getDefiningOp(); 689 if (argumentOp && op->getName() == argumentOp->getName()) { 690 // Replace the outer involutions output with inner's input. 691 return argumentOp->getOperand(0); 692 } 693 694 return {}; 695 } 696 697 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 698 if (op->getNumOperands() != 0) 699 return op->emitOpError() << "requires zero operands"; 700 return success(); 701 } 702 703 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 704 if (op->getNumOperands() != 1) 705 return op->emitOpError() << "requires a single operand"; 706 return success(); 707 } 708 709 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 710 unsigned numOperands) { 711 if (op->getNumOperands() != numOperands) { 712 return op->emitOpError() << "expected " << numOperands 713 << " operands, but found " << op->getNumOperands(); 714 } 715 return success(); 716 } 717 718 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 719 unsigned numOperands) { 720 if (op->getNumOperands() < numOperands) 721 return op->emitOpError() 722 << "expected " << numOperands << " or more operands, but found " 723 << op->getNumOperands(); 724 return success(); 725 } 726 727 /// If this is a vector type, or a 
tensor type, return the scalar element type 728 /// that it is built around, otherwise return the type unmodified. 729 static Type getTensorOrVectorElementType(Type type) { 730 if (auto vec = type.dyn_cast<VectorType>()) 731 return vec.getElementType(); 732 733 // Look through tensor<vector<...>> to find the underlying element type. 734 if (auto tensor = type.dyn_cast<TensorType>()) 735 return getTensorOrVectorElementType(tensor.getElementType()); 736 return type; 737 } 738 739 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 740 // FIXME: Add back check for no side effects on operation. 741 // Currently adding it would cause the shared library build 742 // to fail since there would be a dependency of IR on SideEffectInterfaces 743 // which is cyclical. 744 return success(); 745 } 746 747 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 748 // FIXME: Add back check for no side effects on operation. 749 // Currently adding it would cause the shared library build 750 // to fail since there would be a dependency of IR on SideEffectInterfaces 751 // which is cyclical. 752 return success(); 753 } 754 755 LogicalResult 756 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 757 for (auto opType : op->getOperandTypes()) { 758 auto type = getTensorOrVectorElementType(opType); 759 if (!type.isSignlessIntOrIndex()) 760 return op->emitOpError() << "requires an integer or index type"; 761 } 762 return success(); 763 } 764 765 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 766 for (auto opType : op->getOperandTypes()) { 767 auto type = getTensorOrVectorElementType(opType); 768 if (!type.isa<FloatType>()) 769 return op->emitOpError("requires a float type"); 770 } 771 return success(); 772 } 773 774 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 775 // Zero or one operand always have the "same" type. 
776 unsigned nOperands = op->getNumOperands(); 777 if (nOperands < 2) 778 return success(); 779 780 auto type = op->getOperand(0).getType(); 781 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 782 if (opType != type) 783 return op->emitOpError() << "requires all operands to have the same type"; 784 return success(); 785 } 786 787 LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) { 788 if (op->getNumRegions() != 0) 789 return op->emitOpError() << "requires zero regions"; 790 return success(); 791 } 792 793 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 794 if (op->getNumRegions() != 1) 795 return op->emitOpError() << "requires one region"; 796 return success(); 797 } 798 799 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 800 unsigned numRegions) { 801 if (op->getNumRegions() != numRegions) 802 return op->emitOpError() << "expected " << numRegions << " regions"; 803 return success(); 804 } 805 806 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 807 unsigned numRegions) { 808 if (op->getNumRegions() < numRegions) 809 return op->emitOpError() << "expected " << numRegions << " or more regions"; 810 return success(); 811 } 812 813 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { 814 if (op->getNumResults() != 0) 815 return op->emitOpError() << "requires zero results"; 816 return success(); 817 } 818 819 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 820 if (op->getNumResults() != 1) 821 return op->emitOpError() << "requires one result"; 822 return success(); 823 } 824 825 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 826 unsigned numOperands) { 827 if (op->getNumResults() != numOperands) 828 return op->emitOpError() << "expected " << numOperands << " results"; 829 return success(); 830 } 831 832 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 833 unsigned numOperands) { 834 if (op->getNumResults() < numOperands) 835 return op->emitOpError() 
836 << "expected " << numOperands << " or more results"; 837 return success(); 838 } 839 840 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 841 if (failed(verifyAtLeastNOperands(op, 1))) 842 return failure(); 843 844 if (failed(verifyCompatibleShapes(op->getOperandTypes()))) 845 return op->emitOpError() << "requires the same shape for all operands"; 846 847 return success(); 848 } 849 850 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 851 if (failed(verifyAtLeastNOperands(op, 1)) || 852 failed(verifyAtLeastNResults(op, 1))) 853 return failure(); 854 855 SmallVector<Type, 8> types(op->getOperandTypes()); 856 types.append(llvm::to_vector<4>(op->getResultTypes())); 857 858 if (failed(verifyCompatibleShapes(types))) 859 return op->emitOpError() 860 << "requires the same shape for all operands and results"; 861 862 return success(); 863 } 864 865 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 866 if (failed(verifyAtLeastNOperands(op, 1))) 867 return failure(); 868 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 869 870 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 871 if (getElementTypeOrSelf(operand) != elementType) 872 return op->emitOpError("requires the same element type for all operands"); 873 } 874 875 return success(); 876 } 877 878 LogicalResult 879 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 880 if (failed(verifyAtLeastNOperands(op, 1)) || 881 failed(verifyAtLeastNResults(op, 1))) 882 return failure(); 883 884 auto elementType = getElementTypeOrSelf(op->getResult(0)); 885 886 // Verify result element type matches first result's element type. 
887 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 888 if (getElementTypeOrSelf(result) != elementType) 889 return op->emitOpError( 890 "requires the same element type for all operands and results"); 891 } 892 893 // Verify operand's element type matches first result's element type. 894 for (auto operand : op->getOperands()) { 895 if (getElementTypeOrSelf(operand) != elementType) 896 return op->emitOpError( 897 "requires the same element type for all operands and results"); 898 } 899 900 return success(); 901 } 902 903 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 904 if (failed(verifyAtLeastNOperands(op, 1)) || 905 failed(verifyAtLeastNResults(op, 1))) 906 return failure(); 907 908 auto type = op->getResult(0).getType(); 909 auto elementType = getElementTypeOrSelf(type); 910 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 911 if (getElementTypeOrSelf(resultType) != elementType || 912 failed(verifyCompatibleShape(resultType, type))) 913 return op->emitOpError() 914 << "requires the same type for all operands and results"; 915 } 916 for (auto opType : op->getOperandTypes()) { 917 if (getElementTypeOrSelf(opType) != elementType || 918 failed(verifyCompatibleShape(opType, type))) 919 return op->emitOpError() 920 << "requires the same type for all operands and results"; 921 } 922 return success(); 923 } 924 925 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 926 Block *block = op->getBlock(); 927 // Verify that the operation is at the end of the respective parent block. 928 if (!block || &block->back() != op) 929 return op->emitOpError("must be the last operation in the parent block"); 930 return success(); 931 } 932 933 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 934 auto *parent = op->getParentRegion(); 935 936 // Verify that the operands lines up with the BB arguments in the successor. 
937 for (Block *succ : op->getSuccessors()) 938 if (succ->getParent() != parent) 939 return op->emitError("reference to block defined in another region"); 940 return success(); 941 } 942 943 LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { 944 if (op->getNumSuccessors() != 0) { 945 return op->emitOpError("requires 0 successors but found ") 946 << op->getNumSuccessors(); 947 } 948 return success(); 949 } 950 951 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 952 if (op->getNumSuccessors() != 1) { 953 return op->emitOpError("requires 1 successor but found ") 954 << op->getNumSuccessors(); 955 } 956 return verifyTerminatorSuccessors(op); 957 } 958 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 959 unsigned numSuccessors) { 960 if (op->getNumSuccessors() != numSuccessors) { 961 return op->emitOpError("requires ") 962 << numSuccessors << " successors but found " 963 << op->getNumSuccessors(); 964 } 965 return verifyTerminatorSuccessors(op); 966 } 967 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 968 unsigned numSuccessors) { 969 if (op->getNumSuccessors() < numSuccessors) { 970 return op->emitOpError("requires at least ") 971 << numSuccessors << " successors but found " 972 << op->getNumSuccessors(); 973 } 974 return verifyTerminatorSuccessors(op); 975 } 976 977 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 978 for (auto resultType : op->getResultTypes()) { 979 auto elementType = getTensorOrVectorElementType(resultType); 980 bool isBoolType = elementType.isInteger(1); 981 if (!isBoolType) 982 return op->emitOpError() << "requires a bool result type"; 983 } 984 985 return success(); 986 } 987 988 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 989 for (auto resultType : op->getResultTypes()) 990 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 991 return op->emitOpError() << "requires a floating point type"; 992 993 return success(); 
994 } 995 996 LogicalResult 997 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 998 for (auto resultType : op->getResultTypes()) 999 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 1000 return op->emitOpError() << "requires an integer or index type"; 1001 return success(); 1002 } 1003 1004 LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, 1005 StringRef attrName, 1006 StringRef valueGroupName, 1007 size_t expectedCount) { 1008 auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); 1009 if (!sizeAttr) 1010 return op->emitOpError("requires 1D i32 elements attribute '") 1011 << attrName << "'"; 1012 1013 auto sizeAttrType = sizeAttr.getType(); 1014 if (sizeAttrType.getRank() != 1 || 1015 !sizeAttrType.getElementType().isInteger(32)) 1016 return op->emitOpError("requires 1D i32 elements attribute '") 1017 << attrName << "'"; 1018 1019 if (llvm::any_of(sizeAttr.getValues<APInt>(), [](const APInt &element) { 1020 return !element.isNonNegative(); 1021 })) 1022 return op->emitOpError("'") 1023 << attrName << "' attribute cannot have negative elements"; 1024 1025 size_t totalCount = std::accumulate( 1026 sizeAttr.begin(), sizeAttr.end(), 0, 1027 [](unsigned all, APInt one) { return all + one.getZExtValue(); }); 1028 1029 if (totalCount != expectedCount) 1030 return op->emitOpError() 1031 << valueGroupName << " count (" << expectedCount 1032 << ") does not match with the total size (" << totalCount 1033 << ") specified in attribute '" << attrName << "'"; 1034 return success(); 1035 } 1036 1037 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1038 StringRef attrName) { 1039 return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); 1040 } 1041 1042 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1043 StringRef attrName) { 1044 return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); 1045 } 1046 1047 LogicalResult 
OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1048 for (Region ®ion : op->getRegions()) { 1049 if (region.empty()) 1050 continue; 1051 1052 if (region.getNumArguments() != 0) { 1053 if (op->getNumRegions() > 1) 1054 return op->emitOpError("region #") 1055 << region.getRegionNumber() << " should have no arguments"; 1056 else 1057 return op->emitOpError("region should have no arguments"); 1058 } 1059 } 1060 return success(); 1061 } 1062 1063 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1064 auto isMappableType = [](Type type) { 1065 return type.isa<VectorType, TensorType>(); 1066 }; 1067 auto resultMappableTypes = llvm::to_vector<1>( 1068 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1069 auto operandMappableTypes = llvm::to_vector<2>( 1070 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1071 1072 // If the op only has scalar operand/result types, then we have nothing to 1073 // check. 1074 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1075 return success(); 1076 1077 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1078 return op->emitOpError("if a result is non-scalar, then at least one " 1079 "operand must be non-scalar"); 1080 1081 assert(!operandMappableTypes.empty()); 1082 1083 if (resultMappableTypes.empty()) 1084 return op->emitOpError("if an operand is non-scalar, then there must be at " 1085 "least one non-scalar result"); 1086 1087 if (resultMappableTypes.size() != op->getNumResults()) 1088 return op->emitOpError( 1089 "if an operand is non-scalar, then all results must be non-scalar"); 1090 1091 SmallVector<Type, 4> types = llvm::to_vector<2>( 1092 llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); 1093 TypeID expectedBaseTy = types.front().getTypeID(); 1094 if (!llvm::all_of(types, 1095 [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || 1096 failed(verifyCompatibleShapes(types))) { 1097 return op->emitOpError() << "all non-scalar 
operands/results must have the " 1098 "same shape and base type"; 1099 } 1100 1101 return success(); 1102 } 1103 1104 /// Check for any values used by operations regions attached to the 1105 /// specified "IsIsolatedFromAbove" operation defined outside of it. 1106 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { 1107 assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1108 "Intended to check IsolatedFromAbove ops"); 1109 1110 // List of regions to analyze. Each region is processed independently, with 1111 // respect to the common `limit` region, so we can look at them in any order. 1112 // Therefore, use a simple vector and push/pop back the current region. 1113 SmallVector<Region *, 8> pendingRegions; 1114 for (auto ®ion : isolatedOp->getRegions()) { 1115 pendingRegions.push_back(®ion); 1116 1117 // Traverse all operations in the region. 1118 while (!pendingRegions.empty()) { 1119 for (Operation &op : pendingRegions.pop_back_val()->getOps()) { 1120 for (Value operand : op.getOperands()) { 1121 // operand should be non-null here if the IR is well-formed. But 1122 // we don't assert here as this function is called from the verifier 1123 // and so could be called on invalid IR. 1124 if (!operand) 1125 return op.emitOpError("operation's operand is null"); 1126 1127 // Check that any value that is used by an operation is defined in the 1128 // same region as either an operation result. 1129 auto *operandRegion = operand.getParentRegion(); 1130 if (!region.isAncestor(operandRegion)) { 1131 return op.emitOpError("using value defined outside the region") 1132 .attachNote(isolatedOp->getLoc()) 1133 << "required by region isolation constraints"; 1134 } 1135 } 1136 1137 // Schedule any regions in the operation for further checking. Don't 1138 // recurse into other IsolatedFromAbove ops, because they will check 1139 // themselves. 
1140 if (op.getNumRegions() && 1141 !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1142 for (Region &subRegion : op.getRegions()) 1143 pendingRegions.push_back(&subRegion); 1144 } 1145 } 1146 } 1147 } 1148 1149 return success(); 1150 } 1151 1152 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1153 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1154 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1155 } 1156 1157 //===----------------------------------------------------------------------===// 1158 // BinaryOp implementation 1159 //===----------------------------------------------------------------------===// 1160 1161 // These functions are out-of-line implementations of the methods in BinaryOp, 1162 // which avoids them being template instantiated/duplicated. 1163 1164 void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, 1165 Value rhs) { 1166 assert(lhs.getType() == rhs.getType()); 1167 result.addOperands({lhs, rhs}); 1168 result.types.push_back(lhs.getType()); 1169 } 1170 1171 ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser, 1172 OperationState &result) { 1173 SmallVector<OpAsmParser::OperandType, 2> ops; 1174 Type type; 1175 // If the operand list is in-between parentheses, then we have a generic form. 1176 // (see the fallback in `printOneResultOp`). 
1177 llvm::SMLoc loc = parser.getCurrentLocation(); 1178 if (!parser.parseOptionalLParen()) { 1179 if (parser.parseOperandList(ops) || parser.parseRParen() || 1180 parser.parseOptionalAttrDict(result.attributes) || 1181 parser.parseColon() || parser.parseType(type)) 1182 return failure(); 1183 auto fnType = type.dyn_cast<FunctionType>(); 1184 if (!fnType) { 1185 parser.emitError(loc, "expected function type"); 1186 return failure(); 1187 } 1188 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) 1189 return failure(); 1190 result.addTypes(fnType.getResults()); 1191 return success(); 1192 } 1193 return failure(parser.parseOperandList(ops) || 1194 parser.parseOptionalAttrDict(result.attributes) || 1195 parser.parseColonType(type) || 1196 parser.resolveOperands(ops, type, result.operands) || 1197 parser.addTypeToList(type, result.types)); 1198 } 1199 1200 void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { 1201 assert(op->getNumResults() == 1 && "op should have one result"); 1202 1203 // If not all the operand and result types are the same, just use the 1204 // generic assembly form to avoid omitting information in printing. 1205 auto resultType = op->getResult(0).getType(); 1206 if (llvm::any_of(op->getOperandTypes(), 1207 [&](Type type) { return type != resultType; })) { 1208 p.printGenericOp(op, /*printOpName=*/false); 1209 return; 1210 } 1211 1212 p << ' '; 1213 p.printOperands(op->getOperands()); 1214 p.printOptionalAttrDict(op->getAttrs()); 1215 // Now we can output only one type for all operands and the result. 1216 p << " : " << resultType; 1217 } 1218 1219 //===----------------------------------------------------------------------===// 1220 // CastOp implementation 1221 //===----------------------------------------------------------------------===// 1222 1223 /// Attempt to fold the given cast operation. 
1224 LogicalResult 1225 impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, 1226 SmallVectorImpl<OpFoldResult> &foldResults) { 1227 OperandRange operands = op->getOperands(); 1228 if (operands.empty()) 1229 return failure(); 1230 ResultRange results = op->getResults(); 1231 1232 // Check for the case where the input and output types match 1-1. 1233 if (operands.getTypes() == results.getTypes()) { 1234 foldResults.append(operands.begin(), operands.end()); 1235 return success(); 1236 } 1237 1238 return failure(); 1239 } 1240 1241 /// Attempt to verify the given cast operation. 1242 LogicalResult impl::verifyCastInterfaceOp( 1243 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { 1244 auto resultTypes = op->getResultTypes(); 1245 if (llvm::empty(resultTypes)) 1246 return op->emitOpError() 1247 << "expected at least one result for cast operation"; 1248 1249 auto operandTypes = op->getOperandTypes(); 1250 if (!areCastCompatible(operandTypes, resultTypes)) { 1251 InFlightDiagnostic diag = op->emitOpError("operand type"); 1252 if (llvm::empty(operandTypes)) 1253 diag << "s []"; 1254 else if (llvm::size(operandTypes) == 1) 1255 diag << " " << *operandTypes.begin(); 1256 else 1257 diag << "s " << operandTypes; 1258 return diag << " and result type" << (resultTypes.size() == 1 ? 
" " : "s ") 1259 << resultTypes << " are cast incompatible"; 1260 } 1261 1262 return success(); 1263 } 1264 1265 void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source, 1266 Type destType) { 1267 result.addOperands(source); 1268 result.addTypes(destType); 1269 } 1270 1271 ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) { 1272 OpAsmParser::OperandType srcInfo; 1273 Type srcType, dstType; 1274 return failure(parser.parseOperand(srcInfo) || 1275 parser.parseOptionalAttrDict(result.attributes) || 1276 parser.parseColonType(srcType) || 1277 parser.resolveOperand(srcInfo, srcType, result.operands) || 1278 parser.parseKeywordType("to", dstType) || 1279 parser.addTypeToList(dstType, result.types)); 1280 } 1281 1282 void impl::printCastOp(Operation *op, OpAsmPrinter &p) { 1283 p << ' ' << op->getOperand(0); 1284 p.printOptionalAttrDict(op->getAttrs()); 1285 p << " : " << op->getOperand(0).getType() << " to " 1286 << op->getResult(0).getType(); 1287 } 1288 1289 Value impl::foldCastOp(Operation *op) { 1290 // Identity cast 1291 if (op->getOperand(0).getType() == op->getResult(0).getType()) 1292 return op->getOperand(0); 1293 return nullptr; 1294 } 1295 1296 LogicalResult 1297 impl::verifyCastOp(Operation *op, 1298 function_ref<bool(Type, Type)> areCastCompatible) { 1299 auto opType = op->getOperand(0).getType(); 1300 auto resType = op->getResult(0).getType(); 1301 if (!areCastCompatible(opType, resType)) 1302 return op->emitError("operand type ") 1303 << opType << " and result type " << resType 1304 << " are cast incompatible"; 1305 1306 return success(); 1307 } 1308 1309 //===----------------------------------------------------------------------===// 1310 // Misc. utils 1311 //===----------------------------------------------------------------------===// 1312 1313 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1314 /// region's only block if it does not have a terminator already. 
If the region 1315 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1316 /// terminator operation to insert. 1317 void impl::ensureRegionTerminator( 1318 Region ®ion, OpBuilder &builder, Location loc, 1319 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1320 OpBuilder::InsertionGuard guard(builder); 1321 if (region.empty()) 1322 builder.createBlock(®ion); 1323 1324 Block &block = region.back(); 1325 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1326 return; 1327 1328 builder.setInsertionPointToEnd(&block); 1329 builder.insert(buildTerminatorOp(builder, loc)); 1330 } 1331 1332 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1333 /// function. 1334 void impl::ensureRegionTerminator( 1335 Region ®ion, Builder &builder, Location loc, 1336 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1337 OpBuilder opBuilder(builder.getContext()); 1338 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1339 } 1340