1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/Dialect.h" 12 #include "mlir/IR/OpImplementation.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/IR/StandardTypes.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Interfaces/FoldInterfaces.h" 17 #include <numeric> 18 19 using namespace mlir; 20 21 OpAsmParser::~OpAsmParser() {} 22 23 //===----------------------------------------------------------------------===// 24 // OperationName 25 //===----------------------------------------------------------------------===// 26 27 /// Form the OperationName for an op with the specified string. This either is 28 /// a reference to an AbstractOperation if one is known, or a uniqued Identifier 29 /// if not. 30 OperationName::OperationName(StringRef name, MLIRContext *context) { 31 if (auto *op = AbstractOperation::lookup(name, context)) 32 representation = op; 33 else 34 representation = Identifier::get(name, context); 35 } 36 37 /// Return the name of the dialect this operation is registered to. 38 StringRef OperationName::getDialect() const { 39 return getStringRef().split('.').first; 40 } 41 42 /// Return the operation name with dialect name stripped, if it has one. 43 StringRef OperationName::stripDialect() const { 44 auto splitName = getStringRef().split("."); 45 return splitName.second.empty() ? splitName.first : splitName.second; 46 } 47 48 /// Return the name of this operation. This always succeeds. 49 StringRef OperationName::getStringRef() const { 50 return getIdentifier().strref(); 51 } 52 53 /// Return the name of this operation as an identifier. This always succeeds. 54 Identifier OperationName::getIdentifier() const { 55 if (auto *op = representation.dyn_cast<const AbstractOperation *>()) 56 return op->name; 57 return representation.get<Identifier>(); 58 } 59 60 const AbstractOperation *OperationName::getAbstractOperation() const { 61 return representation.dyn_cast<const AbstractOperation *>(); 62 } 63 64 OperationName OperationName::getFromOpaquePointer(void *pointer) { 65 return OperationName(RepresentationUnion::getFromOpaqueValue(pointer)); 66 } 67 68 //===----------------------------------------------------------------------===// 69 // Operation 70 //===----------------------------------------------------------------------===// 71 72 /// Create a new Operation with the specific fields. 73 Operation *Operation::create(Location location, OperationName name, 74 TypeRange resultTypes, ValueRange operands, 75 ArrayRef<NamedAttribute> attributes, 76 BlockRange successors, unsigned numRegions) { 77 return create(location, name, resultTypes, operands, 78 MutableDictionaryAttr(attributes), successors, numRegions); 79 } 80 81 /// Create a new Operation from operation state. 82 Operation *Operation::create(const OperationState &state) { 83 return create(state.location, state.name, state.types, state.operands, 84 state.attributes, state.successors, state.regions); 85 } 86 87 /// Create a new Operation with the specific fields. 88 Operation *Operation::create(Location location, OperationName name, 89 TypeRange resultTypes, ValueRange operands, 90 MutableDictionaryAttr attributes, 91 BlockRange successors, RegionRange regions) { 92 unsigned numRegions = regions.size(); 93 Operation *op = create(location, name, resultTypes, operands, attributes, 94 successors, numRegions); 95 for (unsigned i = 0; i < numRegions; ++i) 96 if (regions[i]) 97 op->getRegion(i).takeBody(*regions[i]); 98 return op; 99 } 100 101 /// Overload of create that takes an existing MutableDictionaryAttr to avoid 102 /// unnecessarily uniquing a list of attributes. 103 Operation *Operation::create(Location location, OperationName name, 104 TypeRange resultTypes, ValueRange operands, 105 MutableDictionaryAttr attributes, 106 BlockRange successors, unsigned numRegions) { 107 // We only need to allocate additional memory for a subset of results. 108 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 109 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 110 unsigned numSuccessors = successors.size(); 111 unsigned numOperands = operands.size(); 112 113 // If the operation is known to have no operands, don't allocate an operand 114 // storage. 115 bool needsOperandStorage = true; 116 if (operands.empty()) { 117 if (const AbstractOperation *abstractOp = name.getAbstractOperation()) 118 needsOperandStorage = !abstractOp->hasTrait<OpTrait::ZeroOperands>(); 119 } 120 121 // Compute the byte size for the operation and the operand storage. 122 auto byteSize = 123 totalSizeToAlloc<detail::InLineOpResult, detail::TrailingOpResult, 124 BlockOperand, Region, detail::OperandStorage>( 125 numInlineResults, numTrailingResults, numSuccessors, numRegions, 126 needsOperandStorage ? 1 : 0); 127 byteSize += 128 llvm::alignTo(detail::OperandStorage::additionalAllocSize(numOperands), 129 alignof(Operation)); 130 void *rawMem = malloc(byteSize); 131 132 // Create the new Operation. 133 Operation *op = 134 ::new (rawMem) Operation(location, name, resultTypes, numSuccessors, 135 numRegions, attributes, needsOperandStorage); 136 137 assert((numSuccessors == 0 || !op->isKnownNonTerminator()) && 138 "unexpected successors in a non-terminator operation"); 139 140 // Initialize the results. 141 for (unsigned i = 0; i < numInlineResults; ++i) 142 new (op->getInlineResult(i)) detail::InLineOpResult(); 143 for (unsigned i = 0; i < numTrailingResults; ++i) 144 new (op->getTrailingResult(i)) detail::TrailingOpResult(i); 145 146 // Initialize the regions. 147 for (unsigned i = 0; i != numRegions; ++i) 148 new (&op->getRegion(i)) Region(op); 149 150 // Initialize the operands. 151 if (needsOperandStorage) 152 new (&op->getOperandStorage()) detail::OperandStorage(op, operands); 153 154 // Initialize the successors. 155 auto blockOperands = op->getBlockOperands(); 156 for (unsigned i = 0; i != numSuccessors; ++i) 157 new (&blockOperands[i]) BlockOperand(op, successors[i]); 158 159 return op; 160 } 161 162 Operation::Operation(Location location, OperationName name, 163 TypeRange resultTypes, unsigned numSuccessors, 164 unsigned numRegions, 165 const MutableDictionaryAttr &attributes, 166 bool hasOperandStorage) 167 : location(location), numSuccs(numSuccessors), numRegions(numRegions), 168 hasOperandStorage(hasOperandStorage), hasSingleResult(false), name(name), 169 attrs(attributes) { 170 if (!resultTypes.empty()) { 171 // If there is a single result it is stored in-place, otherwise use a tuple. 172 hasSingleResult = resultTypes.size() == 1; 173 if (hasSingleResult) 174 resultType = resultTypes.front(); 175 else 176 resultType = TupleType::get(resultTypes, location->getContext()); 177 } 178 } 179 180 // Operations are deleted through the destroy() member because they are 181 // allocated via malloc. 182 Operation::~Operation() { 183 assert(block == nullptr && "operation destroyed but still in a block"); 184 185 // Explicitly run the destructors for the operands. 186 if (hasOperandStorage) 187 getOperandStorage().~OperandStorage(); 188 189 // Explicitly run the destructors for the successors. 190 for (auto &successor : getBlockOperands()) 191 successor.~BlockOperand(); 192 193 // Explicitly destroy the regions. 194 for (auto ®ion : getRegions()) 195 region.~Region(); 196 } 197 198 /// Destroy this operation or one of its subclasses. 199 void Operation::destroy() { 200 this->~Operation(); 201 free(this); 202 } 203 204 /// Return the context this operation is associated with. 205 MLIRContext *Operation::getContext() { return location->getContext(); } 206 207 /// Return the dialect this operation is associated with, or nullptr if the 208 /// associated dialect is not registered. 209 Dialect *Operation::getDialect() { 210 if (auto *abstractOp = getAbstractOperation()) 211 return &abstractOp->dialect; 212 213 // If this operation hasn't been registered or doesn't have abstract 214 // operation, try looking up the dialect name in the context. 215 return getContext()->getLoadedDialect(getName().getDialect()); 216 } 217 218 Region *Operation::getParentRegion() { 219 return block ? block->getParent() : nullptr; 220 } 221 222 Operation *Operation::getParentOp() { 223 return block ? block->getParentOp() : nullptr; 224 } 225 226 /// Return true if this operation is a proper ancestor of the `other` 227 /// operation. 228 bool Operation::isProperAncestor(Operation *other) { 229 while ((other = other->getParentOp())) 230 if (this == other) 231 return true; 232 return false; 233 } 234 235 /// Replace any uses of 'from' with 'to' within this operation. 236 void Operation::replaceUsesOfWith(Value from, Value to) { 237 if (from == to) 238 return; 239 for (auto &operand : getOpOperands()) 240 if (operand.get() == from) 241 operand.set(to); 242 } 243 244 /// Replace the current operands of this operation with the ones provided in 245 /// 'operands'. 246 void Operation::setOperands(ValueRange operands) { 247 if (LLVM_LIKELY(hasOperandStorage)) 248 return getOperandStorage().setOperands(this, operands); 249 assert(operands.empty() && "setting operands without an operand storage"); 250 } 251 252 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 253 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 254 /// than the range pointed to by 'start'+'length'. 255 void Operation::setOperands(unsigned start, unsigned length, 256 ValueRange operands) { 257 assert((start + length) <= getNumOperands() && 258 "invalid operand range specified"); 259 if (LLVM_LIKELY(hasOperandStorage)) 260 return getOperandStorage().setOperands(this, start, length, operands); 261 assert(operands.empty() && "setting operands without an operand storage"); 262 } 263 264 /// Insert the given operands into the operand list at the given 'index'. 265 void Operation::insertOperands(unsigned index, ValueRange operands) { 266 if (LLVM_LIKELY(hasOperandStorage)) 267 return setOperands(index, /*length=*/0, operands); 268 assert(operands.empty() && "inserting operands without an operand storage"); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // Diagnostics 273 //===----------------------------------------------------------------------===// 274 275 /// Emit an error about fatal conditions with this operation, reporting up to 276 /// any diagnostic handlers that may be listening. 277 InFlightDiagnostic Operation::emitError(const Twine &message) { 278 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 279 if (getContext()->shouldPrintOpOnDiagnostic()) { 280 // Print out the operation explicitly here so that we can print the generic 281 // form. 282 // TODO: It would be nice if we could instead provide the 283 // specific printing flags when adding the operation as an argument to the 284 // diagnostic. 285 std::string printedOp; 286 { 287 llvm::raw_string_ostream os(printedOp); 288 print(os, OpPrintingFlags().printGenericOpForm().useLocalScope()); 289 } 290 diag.attachNote(getLoc()) << "see current operation: " << printedOp; 291 } 292 return diag; 293 } 294 295 /// Emit a warning about this operation, reporting up to any diagnostic 296 /// handlers that may be listening. 297 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 298 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 299 if (getContext()->shouldPrintOpOnDiagnostic()) 300 diag.attachNote(getLoc()) << "see current operation: " << *this; 301 return diag; 302 } 303 304 /// Emit a remark about this operation, reporting up to any diagnostic 305 /// handlers that may be listening. 306 InFlightDiagnostic Operation::emitRemark(const Twine &message) { 307 InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); 308 if (getContext()->shouldPrintOpOnDiagnostic()) 309 diag.attachNote(getLoc()) << "see current operation: " << *this; 310 return diag; 311 } 312 313 //===----------------------------------------------------------------------===// 314 // Operation Ordering 315 //===----------------------------------------------------------------------===// 316 317 constexpr unsigned Operation::kInvalidOrderIdx; 318 constexpr unsigned Operation::kOrderStride; 319 320 /// Given an operation 'other' that is within the same parent block, return 321 /// whether the current operation is before 'other' in the operation list 322 /// of the parent block. 323 /// Note: This function has an average complexity of O(1), but worst case may 324 /// take O(N) where N is the number of operations within the parent block. 325 bool Operation::isBeforeInBlock(Operation *other) { 326 assert(block && "Operations without parent blocks have no order."); 327 assert(other && other->block == block && 328 "Expected other operation to have the same parent block."); 329 // If the order of the block is already invalid, directly recompute the 330 // parent. 331 if (!block->isOpOrderValid()) { 332 block->recomputeOpOrder(); 333 } else { 334 // Update the order either operation if necessary. 335 updateOrderIfNecessary(); 336 other->updateOrderIfNecessary(); 337 } 338 339 return orderIndex < other->orderIndex; 340 } 341 342 /// Update the order index of this operation of this operation if necessary, 343 /// potentially recomputing the order of the parent block. 344 void Operation::updateOrderIfNecessary() { 345 assert(block && "expected valid parent"); 346 347 // If the order is valid for this operation there is nothing to do. 348 if (hasValidOrder()) 349 return; 350 Operation *blockFront = &block->front(); 351 Operation *blockBack = &block->back(); 352 353 // This method is expected to only be invoked on blocks with more than one 354 // operation. 355 assert(blockFront != blockBack && "expected more than one operation"); 356 357 // If the operation is at the end of the block. 358 if (this == blockBack) { 359 Operation *prevNode = getPrevNode(); 360 if (!prevNode->hasValidOrder()) 361 return block->recomputeOpOrder(); 362 363 // Add the stride to the previous operation. 364 orderIndex = prevNode->orderIndex + kOrderStride; 365 return; 366 } 367 368 // If this is the first operation try to use the next operation to compute the 369 // ordering. 370 if (this == blockFront) { 371 Operation *nextNode = getNextNode(); 372 if (!nextNode->hasValidOrder()) 373 return block->recomputeOpOrder(); 374 // There is no order to give this operation. 375 if (nextNode->orderIndex == 0) 376 return block->recomputeOpOrder(); 377 378 // If we can't use the stride, just take the middle value left. This is safe 379 // because we know there is at least one valid index to assign to. 380 if (nextNode->orderIndex <= kOrderStride) 381 orderIndex = (nextNode->orderIndex / 2); 382 else 383 orderIndex = kOrderStride; 384 return; 385 } 386 387 // Otherwise, this operation is between two others. Place this operation in 388 // the middle of the previous and next if possible. 389 Operation *prevNode = getPrevNode(), *nextNode = getNextNode(); 390 if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder()) 391 return block->recomputeOpOrder(); 392 unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex; 393 394 // Check to see if there is a valid order between the two. 395 if (prevOrder + 1 == nextOrder) 396 return block->recomputeOpOrder(); 397 orderIndex = prevOrder + ((nextOrder - prevOrder) / 2); 398 } 399 400 //===----------------------------------------------------------------------===// 401 // ilist_traits for Operation 402 //===----------------------------------------------------------------------===// 403 404 auto llvm::ilist_detail::SpecificNodeAccess< 405 typename llvm::ilist_detail::compute_node_options< 406 ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * { 407 return NodeAccess::getNodePtr<OptionsT>(N); 408 } 409 410 auto llvm::ilist_detail::SpecificNodeAccess< 411 typename llvm::ilist_detail::compute_node_options< 412 ::mlir::Operation>::type>::getNodePtr(const_pointer N) 413 -> const node_type * { 414 return NodeAccess::getNodePtr<OptionsT>(N); 415 } 416 417 auto llvm::ilist_detail::SpecificNodeAccess< 418 typename llvm::ilist_detail::compute_node_options< 419 ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer { 420 return NodeAccess::getValuePtr<OptionsT>(N); 421 } 422 423 auto llvm::ilist_detail::SpecificNodeAccess< 424 typename llvm::ilist_detail::compute_node_options< 425 ::mlir::Operation>::type>::getValuePtr(const node_type *N) 426 -> const_pointer { 427 return NodeAccess::getValuePtr<OptionsT>(N); 428 } 429 430 void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { 431 op->destroy(); 432 } 433 434 Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { 435 size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); 436 iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this)); 437 return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset); 438 } 439 440 /// This is a trait method invoked when an operation is added to a block. We 441 /// keep the block pointer up to date. 442 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 443 assert(!op->getBlock() && "already in an operation block!"); 444 op->block = getContainingBlock(); 445 446 // Invalidate the order on the operation. 447 op->orderIndex = Operation::kInvalidOrderIdx; 448 } 449 450 /// This is a trait method invoked when an operation is removed from a block. 451 /// We keep the block pointer up to date. 452 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 453 assert(op->block && "not already in an operation block!"); 454 op->block = nullptr; 455 } 456 457 /// This is a trait method invoked when an operation is moved from one block 458 /// to another. We keep the block pointer up to date. 459 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 460 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 461 Block *curParent = getContainingBlock(); 462 463 // Invalidate the ordering of the parent block. 464 curParent->invalidateOpOrder(); 465 466 // If we are transferring operations within the same block, the block 467 // pointer doesn't need to be updated. 468 if (curParent == otherList.getContainingBlock()) 469 return; 470 471 // Update the 'block' member of each operation. 472 for (; first != last; ++first) 473 first->block = curParent; 474 } 475 476 /// Remove this operation (and its descendants) from its Block and delete 477 /// all of them. 478 void Operation::erase() { 479 if (auto *parent = getBlock()) 480 parent->getOperations().erase(this); 481 else 482 destroy(); 483 } 484 485 /// Unlink this operation from its current block and insert it right before 486 /// `existingOp` which may be in the same or another block in the same 487 /// function. 488 void Operation::moveBefore(Operation *existingOp) { 489 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 490 } 491 492 /// Unlink this operation from its current basic block and insert it right 493 /// before `iterator` in the specified basic block. 494 void Operation::moveBefore(Block *block, 495 llvm::iplist<Operation>::iterator iterator) { 496 block->getOperations().splice(iterator, getBlock()->getOperations(), 497 getIterator()); 498 } 499 500 /// Unlink this operation from its current block and insert it right after 501 /// `existingOp` which may be in the same or another block in the same function. 502 void Operation::moveAfter(Operation *existingOp) { 503 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 504 } 505 506 /// Unlink this operation from its current block and insert it right after 507 /// `iterator` in the specified block. 508 void Operation::moveAfter(Block *block, 509 llvm::iplist<Operation>::iterator iterator) { 510 assert(iterator != block->end() && "cannot move after end of block"); 511 moveBefore(&*std::next(iterator)); 512 } 513 514 /// This drops all operand uses from this operation, which is an essential 515 /// step in breaking cyclic dependences between references when they are to 516 /// be deleted. 517 void Operation::dropAllReferences() { 518 for (auto &op : getOpOperands()) 519 op.drop(); 520 521 for (auto ®ion : getRegions()) 522 region.dropAllReferences(); 523 524 for (auto &dest : getBlockOperands()) 525 dest.drop(); 526 } 527 528 /// This drops all uses of any values defined by this operation or its nested 529 /// regions, wherever they are located. 530 void Operation::dropAllDefinedValueUses() { 531 dropAllUses(); 532 533 for (auto ®ion : getRegions()) 534 for (auto &block : region) 535 block.dropAllDefinedValueUses(); 536 } 537 538 /// Return the number of results held by this operation. 539 unsigned Operation::getNumResults() { 540 if (!resultType) 541 return 0; 542 return hasSingleResult ? 1 : resultType.cast<TupleType>().size(); 543 } 544 545 auto Operation::getResultTypes() -> result_type_range { 546 if (!resultType) 547 return llvm::None; 548 if (hasSingleResult) 549 return resultType; 550 return resultType.cast<TupleType>().getTypes(); 551 } 552 553 void Operation::setSuccessor(Block *block, unsigned index) { 554 assert(index < getNumSuccessors()); 555 getBlockOperands()[index].set(block); 556 } 557 558 /// Attempt to fold this operation using the Op's registered foldHook. 559 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 560 SmallVectorImpl<OpFoldResult> &results) { 561 // If we have a registered operation definition matching this one, use it to 562 // try to constant fold the operation. 563 auto *abstractOp = getAbstractOperation(); 564 if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results))) 565 return success(); 566 567 // Otherwise, fall back on the dialect hook to handle it. 568 Dialect *dialect = getDialect(); 569 if (!dialect) 570 return failure(); 571 572 auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>(); 573 if (!interface) 574 return failure(); 575 576 return interface->fold(this, operands, results); 577 } 578 579 /// Emit an error with the op name prefixed, like "'dim' op " which is 580 /// convenient for verifiers. 581 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 582 return emitError() << "'" << getName() << "' op " << message; 583 } 584 585 //===----------------------------------------------------------------------===// 586 // Operation Cloning 587 //===----------------------------------------------------------------------===// 588 589 /// Create a deep copy of this operation but keep the operation regions empty. 590 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 591 /// to contain the results. 592 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 593 SmallVector<Value, 8> operands; 594 SmallVector<Block *, 2> successors; 595 596 // Remap the operands. 597 operands.reserve(getNumOperands()); 598 for (auto opValue : getOperands()) 599 operands.push_back(mapper.lookupOrDefault(opValue)); 600 601 // Remap the successors. 602 successors.reserve(getNumSuccessors()); 603 for (Block *successor : getSuccessors()) 604 successors.push_back(mapper.lookupOrDefault(successor)); 605 606 // Create the new operation. 607 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 608 successors, getNumRegions()); 609 610 // Remember the mapping of any results. 611 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 612 mapper.map(getResult(i), newOp->getResult(i)); 613 614 return newOp; 615 } 616 617 Operation *Operation::cloneWithoutRegions() { 618 BlockAndValueMapping mapper; 619 return cloneWithoutRegions(mapper); 620 } 621 622 /// Create a deep copy of this operation, remapping any operands that use 623 /// values outside of the operation using the map that is provided (leaving 624 /// them alone if no entry is present). Replaces references to cloned 625 /// sub-operations to the corresponding operation that is copied, and adds 626 /// those mappings to the map. 627 Operation *Operation::clone(BlockAndValueMapping &mapper) { 628 auto *newOp = cloneWithoutRegions(mapper); 629 630 // Clone the regions. 631 for (unsigned i = 0; i != numRegions; ++i) 632 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 633 634 return newOp; 635 } 636 637 Operation *Operation::clone() { 638 BlockAndValueMapping mapper; 639 return clone(mapper); 640 } 641 642 //===----------------------------------------------------------------------===// 643 // OpState trait class. 644 //===----------------------------------------------------------------------===// 645 646 // The fallback for the parser is to reject the custom assembly form. 647 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 648 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 649 } 650 651 // The fallback for the printer is to print in the generic assembly form. 652 void OpState::print(OpAsmPrinter &p) { p.printGenericOp(getOperation()); } 653 654 /// Emit an error about fatal conditions with this operation, reporting up to 655 /// any diagnostic handlers that may be listening. 656 InFlightDiagnostic OpState::emitError(const Twine &message) { 657 return getOperation()->emitError(message); 658 } 659 660 /// Emit an error with the op name prefixed, like "'dim' op " which is 661 /// convenient for verifiers. 662 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 663 return getOperation()->emitOpError(message); 664 } 665 666 /// Emit a warning about this operation, reporting up to any diagnostic 667 /// handlers that may be listening. 668 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 669 return getOperation()->emitWarning(message); 670 } 671 672 /// Emit a remark about this operation, reporting up to any diagnostic 673 /// handlers that may be listening. 674 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 675 return getOperation()->emitRemark(message); 676 } 677 678 //===----------------------------------------------------------------------===// 679 // Op Trait implementations 680 //===----------------------------------------------------------------------===// 681 682 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 683 auto *argumentOp = op->getOperand(0).getDefiningOp(); 684 if (argumentOp && op->getName() == argumentOp->getName()) { 685 // Replace the outer operation output with the inner operation. 686 return op->getOperand(0); 687 } 688 689 return {}; 690 } 691 692 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 693 auto *argumentOp = op->getOperand(0).getDefiningOp(); 694 if (argumentOp && op->getName() == argumentOp->getName()) { 695 // Replace the outer involutions output with inner's input. 696 return argumentOp->getOperand(0); 697 } 698 699 return {}; 700 } 701 702 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 703 if (op->getNumOperands() != 0) 704 return op->emitOpError() << "requires zero operands"; 705 return success(); 706 } 707 708 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 709 if (op->getNumOperands() != 1) 710 return op->emitOpError() << "requires a single operand"; 711 return success(); 712 } 713 714 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 715 unsigned numOperands) { 716 if (op->getNumOperands() != numOperands) { 717 return op->emitOpError() << "expected " << numOperands 718 << " operands, but found " << op->getNumOperands(); 719 } 720 return success(); 721 } 722 723 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 724 unsigned numOperands) { 725 if (op->getNumOperands() < numOperands) 726 return op->emitOpError() 727 << "expected " << numOperands << " or more operands"; 728 return success(); 729 } 730 731 /// If this is a vector type, or a tensor type, return the scalar element type 732 /// that it is built around, otherwise return the type unmodified. 733 static Type getTensorOrVectorElementType(Type type) { 734 if (auto vec = type.dyn_cast<VectorType>()) 735 return vec.getElementType(); 736 737 // Look through tensor<vector<...>> to find the underlying element type. 738 if (auto tensor = type.dyn_cast<TensorType>()) 739 return getTensorOrVectorElementType(tensor.getElementType()); 740 return type; 741 } 742 743 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 744 // FIXME: Add back check for no side effects on operation. 745 // Currently adding it would cause the shared library build 746 // to fail since there would be a dependency of IR on SideEffectInterfaces 747 // which is cyclical. 748 return success(); 749 } 750 751 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 752 // FIXME: Add back check for no side effects on operation. 753 // Currently adding it would cause the shared library build 754 // to fail since there would be a dependency of IR on SideEffectInterfaces 755 // which is cyclical. 756 return success(); 757 } 758 759 LogicalResult 760 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 761 for (auto opType : op->getOperandTypes()) { 762 auto type = getTensorOrVectorElementType(opType); 763 if (!type.isSignlessIntOrIndex()) 764 return op->emitOpError() << "requires an integer or index type"; 765 } 766 return success(); 767 } 768 769 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 770 for (auto opType : op->getOperandTypes()) { 771 auto type = getTensorOrVectorElementType(opType); 772 if (!type.isa<FloatType>()) 773 return op->emitOpError("requires a float type"); 774 } 775 return success(); 776 } 777 778 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 779 // Zero or one operand always have the "same" type. 780 unsigned nOperands = op->getNumOperands(); 781 if (nOperands < 2) 782 return success(); 783 784 auto type = op->getOperand(0).getType(); 785 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 786 if (opType != type) 787 return op->emitOpError() << "requires all operands to have the same type"; 788 return success(); 789 } 790 791 LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) { 792 if (op->getNumRegions() != 0) 793 return op->emitOpError() << "requires zero regions"; 794 return success(); 795 } 796 797 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 798 if (op->getNumRegions() != 1) 799 return op->emitOpError() << "requires one region"; 800 return success(); 801 } 802 803 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 804 unsigned numRegions) { 805 if (op->getNumRegions() != numRegions) 806 return op->emitOpError() << "expected " << numRegions << " regions"; 807 return success(); 808 } 809 810 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 811 unsigned numRegions) { 812 if (op->getNumRegions() < numRegions) 813 return op->emitOpError() << "expected " << numRegions << " or more regions"; 814 return success(); 815 } 816 817 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { 818 if (op->getNumResults() != 0) 819 return op->emitOpError() << "requires zero results"; 820 return success(); 821 } 822 823 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 824 if (op->getNumResults() != 1) 825 return op->emitOpError() << "requires one result"; 826 return success(); 827 } 828 829 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 830 unsigned numOperands) { 831 if (op->getNumResults() != numOperands) 832 return op->emitOpError() << "expected " << numOperands << " results"; 833 return success(); 834 } 835 836 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 837 unsigned numOperands) { 838 if (op->getNumResults() < numOperands) 839 return op->emitOpError() 840 << "expected " << numOperands << " or more results"; 841 return success(); 842 } 843 844 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 845 if (failed(verifyAtLeastNOperands(op, 1))) 846 return failure(); 847 848 auto type = op->getOperand(0).getType(); 849 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) { 850 if (failed(verifyCompatibleShape(opType, type))) 851 return op->emitOpError() << "requires the same shape for all operands"; 852 } 853 return success(); 854 } 855 856 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 857 if (failed(verifyAtLeastNOperands(op, 1)) || 858 failed(verifyAtLeastNResults(op, 1))) 859 return failure(); 860 861 auto type = op->getOperand(0).getType(); 862 for (auto resultType : op->getResultTypes()) { 863 if (failed(verifyCompatibleShape(resultType, type))) 864 return op->emitOpError() 865 << "requires the same shape for all operands and results"; 866 } 867 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) { 868 if (failed(verifyCompatibleShape(opType, type))) 869 return op->emitOpError() 870 << "requires the same shape for all operands and results"; 871 } 872 return success(); 873 } 874 875 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 876 if (failed(verifyAtLeastNOperands(op, 1))) 877 return failure(); 878 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 879 880 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 881 if (getElementTypeOrSelf(operand) != elementType) 882 return op->emitOpError("requires the same element type for all operands"); 883 } 884 885 return success(); 886 } 887 888 LogicalResult 889 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 890 if (failed(verifyAtLeastNOperands(op, 1)) || 891 failed(verifyAtLeastNResults(op, 1))) 892 return failure(); 893 894 auto elementType = getElementTypeOrSelf(op->getResult(0)); 895 896 // Verify result element type matches first result's element type. 897 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 898 if (getElementTypeOrSelf(result) != elementType) 899 return op->emitOpError( 900 "requires the same element type for all operands and results"); 901 } 902 903 // Verify operand's element type matches first result's element type. 904 for (auto operand : op->getOperands()) { 905 if (getElementTypeOrSelf(operand) != elementType) 906 return op->emitOpError( 907 "requires the same element type for all operands and results"); 908 } 909 910 return success(); 911 } 912 913 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 914 if (failed(verifyAtLeastNOperands(op, 1)) || 915 failed(verifyAtLeastNResults(op, 1))) 916 return failure(); 917 918 auto type = op->getResult(0).getType(); 919 auto elementType = getElementTypeOrSelf(type); 920 for (auto resultType : op->getResultTypes().drop_front(1)) { 921 if (getElementTypeOrSelf(resultType) != elementType || 922 failed(verifyCompatibleShape(resultType, type))) 923 return op->emitOpError() 924 << "requires the same type for all operands and results"; 925 } 926 for (auto opType : op->getOperandTypes()) { 927 if (getElementTypeOrSelf(opType) != elementType || 928 failed(verifyCompatibleShape(opType, type))) 929 return op->emitOpError() 930 << "requires the same type for all operands and results"; 931 } 932 return success(); 933 } 934 935 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 936 Block *block = op->getBlock(); 937 // Verify that the operation is at the end of the respective parent block. 938 if (!block || &block->back() != op) 939 return op->emitOpError("must be the last operation in the parent block"); 940 return success(); 941 } 942 943 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 944 auto *parent = op->getParentRegion(); 945 946 // Verify that the operands lines up with the BB arguments in the successor. 947 for (Block *succ : op->getSuccessors()) 948 if (succ->getParent() != parent) 949 return op->emitError("reference to block defined in another region"); 950 return success(); 951 } 952 953 LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { 954 if (op->getNumSuccessors() != 0) { 955 return op->emitOpError("requires 0 successors but found ") 956 << op->getNumSuccessors(); 957 } 958 return success(); 959 } 960 961 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 962 if (op->getNumSuccessors() != 1) { 963 return op->emitOpError("requires 1 successor but found ") 964 << op->getNumSuccessors(); 965 } 966 return verifyTerminatorSuccessors(op); 967 } 968 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 969 unsigned numSuccessors) { 970 if (op->getNumSuccessors() != numSuccessors) { 971 return op->emitOpError("requires ") 972 << numSuccessors << " successors but found " 973 << op->getNumSuccessors(); 974 } 975 return verifyTerminatorSuccessors(op); 976 } 977 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 978 unsigned numSuccessors) { 979 if (op->getNumSuccessors() < numSuccessors) { 980 return op->emitOpError("requires at least ") 981 << numSuccessors << " successors but found " 982 << op->getNumSuccessors(); 983 } 984 return verifyTerminatorSuccessors(op); 985 } 986 987 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 988 for (auto resultType : op->getResultTypes()) { 989 auto elementType = getTensorOrVectorElementType(resultType); 990 bool isBoolType = elementType.isInteger(1); 991 if (!isBoolType) 992 return op->emitOpError() << "requires a bool result type"; 993 } 994 995 return success(); 996 } 997 998 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 999 for (auto resultType : op->getResultTypes()) 1000 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 1001 return op->emitOpError() << "requires a floating point type"; 1002 1003 return success(); 1004 } 1005 1006 LogicalResult 1007 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 1008 for (auto resultType : op->getResultTypes()) 1009 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 1010 return op->emitOpError() << "requires an integer or index type"; 1011 return success(); 1012 } 1013 1014 static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName, 1015 bool isOperand) { 1016 auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); 1017 if (!sizeAttr) 1018 return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; 1019 1020 auto sizeAttrType = sizeAttr.getType().dyn_cast<VectorType>(); 1021 if (!sizeAttrType || sizeAttrType.getRank() != 1) 1022 return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; 1023 1024 if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) { 1025 return !element.isNonNegative(); 1026 })) 1027 return op->emitOpError("'") 1028 << attrName << "' attribute cannot have negative elements"; 1029 1030 size_t totalCount = std::accumulate( 1031 sizeAttr.begin(), sizeAttr.end(), 0, 1032 [](unsigned all, APInt one) { return all + one.getZExtValue(); }); 1033 1034 if (isOperand && totalCount != op->getNumOperands()) 1035 return op->emitOpError("operand count (") 1036 << op->getNumOperands() << ") does not match with the total size (" 1037 << totalCount << ") specified in attribute '" << attrName << "'"; 1038 else if (!isOperand && totalCount != op->getNumResults()) 1039 return op->emitOpError("result count (") 1040 << op->getNumResults() << ") does not match with the total size (" 1041 << totalCount << ") specified in attribute '" << attrName << "'"; 1042 return success(); 1043 } 1044 1045 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1046 StringRef attrName) { 1047 return verifyValueSizeAttr(op, attrName, /*isOperand=*/true); 1048 } 1049 1050 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1051 StringRef attrName) { 1052 return verifyValueSizeAttr(op, attrName, /*isOperand=*/false); 1053 } 1054 1055 LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1056 for (Region ®ion : op->getRegions()) { 1057 if (region.empty()) 1058 continue; 1059 1060 if (region.getNumArguments() != 0) { 1061 if (op->getNumRegions() > 1) 1062 return op->emitOpError("region #") 1063 << region.getRegionNumber() << " should have no arguments"; 1064 else 1065 return op->emitOpError("region should have no arguments"); 1066 } 1067 } 1068 return success(); 1069 } 1070 1071 //===----------------------------------------------------------------------===// 1072 // BinaryOp implementation 1073 //===----------------------------------------------------------------------===// 1074 1075 // These functions are out-of-line implementations of the methods in BinaryOp, 1076 // which avoids them being template instantiated/duplicated. 1077 1078 void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, 1079 Value rhs) { 1080 assert(lhs.getType() == rhs.getType()); 1081 result.addOperands({lhs, rhs}); 1082 result.types.push_back(lhs.getType()); 1083 } 1084 1085 ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser, 1086 OperationState &result) { 1087 SmallVector<OpAsmParser::OperandType, 2> ops; 1088 Type type; 1089 return failure(parser.parseOperandList(ops) || 1090 parser.parseOptionalAttrDict(result.attributes) || 1091 parser.parseColonType(type) || 1092 parser.resolveOperands(ops, type, result.operands) || 1093 parser.addTypeToList(type, result.types)); 1094 } 1095 1096 void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { 1097 assert(op->getNumResults() == 1 && "op should have one result"); 1098 1099 // If not all the operand and result types are the same, just use the 1100 // generic assembly form to avoid omitting information in printing. 1101 auto resultType = op->getResult(0).getType(); 1102 if (llvm::any_of(op->getOperandTypes(), 1103 [&](Type type) { return type != resultType; })) { 1104 p.printGenericOp(op); 1105 return; 1106 } 1107 1108 p << op->getName() << ' '; 1109 p.printOperands(op->getOperands()); 1110 p.printOptionalAttrDict(op->getAttrs()); 1111 // Now we can output only one type for all operands and the result. 1112 p << " : " << resultType; 1113 } 1114 1115 //===----------------------------------------------------------------------===// 1116 // CastOp implementation 1117 //===----------------------------------------------------------------------===// 1118 1119 void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source, 1120 Type destType) { 1121 result.addOperands(source); 1122 result.addTypes(destType); 1123 } 1124 1125 ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) { 1126 OpAsmParser::OperandType srcInfo; 1127 Type srcType, dstType; 1128 return failure(parser.parseOperand(srcInfo) || 1129 parser.parseOptionalAttrDict(result.attributes) || 1130 parser.parseColonType(srcType) || 1131 parser.resolveOperand(srcInfo, srcType, result.operands) || 1132 parser.parseKeywordType("to", dstType) || 1133 parser.addTypeToList(dstType, result.types)); 1134 } 1135 1136 void impl::printCastOp(Operation *op, OpAsmPrinter &p) { 1137 p << op->getName() << ' ' << op->getOperand(0); 1138 p.printOptionalAttrDict(op->getAttrs()); 1139 p << " : " << op->getOperand(0).getType() << " to " 1140 << op->getResult(0).getType(); 1141 } 1142 1143 Value impl::foldCastOp(Operation *op) { 1144 // Identity cast 1145 if (op->getOperand(0).getType() == op->getResult(0).getType()) 1146 return op->getOperand(0); 1147 return nullptr; 1148 } 1149 1150 //===----------------------------------------------------------------------===// 1151 // Misc. utils 1152 //===----------------------------------------------------------------------===// 1153 1154 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1155 /// region's only block if it does not have a terminator already. If the region 1156 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1157 /// terminator operation to insert. 1158 void impl::ensureRegionTerminator( 1159 Region ®ion, OpBuilder &builder, Location loc, 1160 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1161 OpBuilder::InsertionGuard guard(builder); 1162 if (region.empty()) 1163 builder.createBlock(®ion); 1164 1165 Block &block = region.back(); 1166 if (!block.empty() && block.back().isKnownTerminator()) 1167 return; 1168 1169 builder.setInsertionPointToEnd(&block); 1170 builder.insert(buildTerminatorOp(builder, loc)); 1171 } 1172 1173 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1174 /// function. 1175 void impl::ensureRegionTerminator( 1176 Region ®ion, Builder &builder, Location loc, 1177 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1178 OpBuilder opBuilder(builder.getContext()); 1179 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1180 } 1181 1182 //===----------------------------------------------------------------------===// 1183 // UseIterator 1184 //===----------------------------------------------------------------------===// 1185 1186 Operation::UseIterator::UseIterator(Operation *op, bool end) 1187 : op(op), res(end ? op->result_end() : op->result_begin()) { 1188 // Only initialize current use if there are results/can be uses. 1189 if (op->getNumResults()) 1190 skipOverResultsWithNoUsers(); 1191 } 1192 1193 Operation::UseIterator &Operation::UseIterator::operator++() { 1194 // We increment over uses, if we reach the last use then move to next 1195 // result. 1196 if (use != (*res).use_end()) 1197 ++use; 1198 if (use == (*res).use_end()) { 1199 ++res; 1200 skipOverResultsWithNoUsers(); 1201 } 1202 return *this; 1203 } 1204 1205 void Operation::UseIterator::skipOverResultsWithNoUsers() { 1206 while (res != op->result_end() && (*res).use_empty()) 1207 ++res; 1208 1209 // If we are at the last result, then set use to first use of 1210 // first result (sentinel value used for end). 1211 if (res == op->result_end()) 1212 use = {}; 1213 else 1214 use = (*res).use_begin(); 1215 } 1216