1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/Dialect.h" 12 #include "mlir/IR/OpImplementation.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/IR/StandardTypes.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Interfaces/FoldInterfaces.h" 17 #include <numeric> 18 19 using namespace mlir; 20 21 OpAsmParser::~OpAsmParser() {} 22 23 //===----------------------------------------------------------------------===// 24 // OperationName 25 //===----------------------------------------------------------------------===// 26 27 /// Form the OperationName for an op with the specified string. This either is 28 /// a reference to an AbstractOperation if one is known, or a uniqued Identifier 29 /// if not. 30 OperationName::OperationName(StringRef name, MLIRContext *context) { 31 if (auto *op = AbstractOperation::lookup(name, context)) 32 representation = op; 33 else 34 representation = Identifier::get(name, context); 35 } 36 37 /// Return the name of the dialect this operation is registered to. 38 StringRef OperationName::getDialect() const { 39 return getStringRef().split('.').first; 40 } 41 42 /// Return the operation name with dialect name stripped, if it has one. 43 StringRef OperationName::stripDialect() const { 44 auto splitName = getStringRef().split("."); 45 return splitName.second.empty() ? splitName.first : splitName.second; 46 } 47 48 /// Return the name of this operation. This always succeeds. 49 StringRef OperationName::getStringRef() const { 50 return getIdentifier().strref(); 51 } 52 53 /// Return the name of this operation as an identifier. This always succeeds. 54 Identifier OperationName::getIdentifier() const { 55 if (auto *op = representation.dyn_cast<const AbstractOperation *>()) 56 return op->name; 57 return representation.get<Identifier>(); 58 } 59 60 const AbstractOperation *OperationName::getAbstractOperation() const { 61 return representation.dyn_cast<const AbstractOperation *>(); 62 } 63 64 OperationName OperationName::getFromOpaquePointer(void *pointer) { 65 return OperationName(RepresentationUnion::getFromOpaqueValue(pointer)); 66 } 67 68 //===----------------------------------------------------------------------===// 69 // Operation 70 //===----------------------------------------------------------------------===// 71 72 /// Create a new Operation with the specific fields. 73 Operation *Operation::create(Location location, OperationName name, 74 TypeRange resultTypes, ValueRange operands, 75 ArrayRef<NamedAttribute> attributes, 76 BlockRange successors, unsigned numRegions) { 77 return create(location, name, resultTypes, operands, 78 MutableDictionaryAttr(attributes), successors, numRegions); 79 } 80 81 /// Create a new Operation from operation state. 82 Operation *Operation::create(const OperationState &state) { 83 return create(state.location, state.name, state.types, state.operands, 84 state.attributes, state.successors, state.regions); 85 } 86 87 /// Create a new Operation with the specific fields. 88 Operation *Operation::create(Location location, OperationName name, 89 TypeRange resultTypes, ValueRange operands, 90 MutableDictionaryAttr attributes, 91 BlockRange successors, RegionRange regions) { 92 unsigned numRegions = regions.size(); 93 Operation *op = create(location, name, resultTypes, operands, attributes, 94 successors, numRegions); 95 for (unsigned i = 0; i < numRegions; ++i) 96 if (regions[i]) 97 op->getRegion(i).takeBody(*regions[i]); 98 return op; 99 } 100 101 /// Overload of create that takes an existing MutableDictionaryAttr to avoid 102 /// unnecessarily uniquing a list of attributes. 103 Operation *Operation::create(Location location, OperationName name, 104 TypeRange resultTypes, ValueRange operands, 105 MutableDictionaryAttr attributes, 106 BlockRange successors, unsigned numRegions) { 107 // We only need to allocate additional memory for a subset of results. 108 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 109 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 110 unsigned numSuccessors = successors.size(); 111 unsigned numOperands = operands.size(); 112 113 // If the operation is known to have no operands, don't allocate an operand 114 // storage. 115 bool needsOperandStorage = true; 116 if (operands.empty()) { 117 if (const AbstractOperation *abstractOp = name.getAbstractOperation()) 118 needsOperandStorage = !abstractOp->hasTrait<OpTrait::ZeroOperands>(); 119 } 120 121 // Compute the byte size for the operation and the operand storage. 122 auto byteSize = 123 totalSizeToAlloc<detail::InLineOpResult, detail::TrailingOpResult, 124 BlockOperand, Region, detail::OperandStorage>( 125 numInlineResults, numTrailingResults, numSuccessors, numRegions, 126 needsOperandStorage ? 1 : 0); 127 byteSize += 128 llvm::alignTo(detail::OperandStorage::additionalAllocSize(numOperands), 129 alignof(Operation)); 130 void *rawMem = malloc(byteSize); 131 132 // Create the new Operation. 133 Operation *op = 134 ::new (rawMem) Operation(location, name, resultTypes, numSuccessors, 135 numRegions, attributes, needsOperandStorage); 136 137 assert((numSuccessors == 0 || !op->isKnownNonTerminator()) && 138 "unexpected successors in a non-terminator operation"); 139 140 // Initialize the results. 141 for (unsigned i = 0; i < numInlineResults; ++i) 142 new (op->getInlineResult(i)) detail::InLineOpResult(); 143 for (unsigned i = 0; i < numTrailingResults; ++i) 144 new (op->getTrailingResult(i)) detail::TrailingOpResult(i); 145 146 // Initialize the regions. 147 for (unsigned i = 0; i != numRegions; ++i) 148 new (&op->getRegion(i)) Region(op); 149 150 // Initialize the operands. 151 if (needsOperandStorage) 152 new (&op->getOperandStorage()) detail::OperandStorage(op, operands); 153 154 // Initialize the successors. 155 auto blockOperands = op->getBlockOperands(); 156 for (unsigned i = 0; i != numSuccessors; ++i) 157 new (&blockOperands[i]) BlockOperand(op, successors[i]); 158 159 return op; 160 } 161 162 Operation::Operation(Location location, OperationName name, 163 TypeRange resultTypes, unsigned numSuccessors, 164 unsigned numRegions, 165 const MutableDictionaryAttr &attributes, 166 bool hasOperandStorage) 167 : location(location), numSuccs(numSuccessors), numRegions(numRegions), 168 hasOperandStorage(hasOperandStorage), hasSingleResult(false), name(name), 169 attrs(attributes) { 170 assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && 171 "unexpected null result type"); 172 if (!resultTypes.empty()) { 173 // If there is a single result it is stored in-place, otherwise use a tuple. 174 hasSingleResult = resultTypes.size() == 1; 175 if (hasSingleResult) 176 resultType = resultTypes.front(); 177 else 178 resultType = TupleType::get(resultTypes, location->getContext()); 179 } 180 } 181 182 // Operations are deleted through the destroy() member because they are 183 // allocated via malloc. 184 Operation::~Operation() { 185 assert(block == nullptr && "operation destroyed but still in a block"); 186 187 // Explicitly run the destructors for the operands. 188 if (hasOperandStorage) 189 getOperandStorage().~OperandStorage(); 190 191 // Explicitly run the destructors for the successors. 192 for (auto &successor : getBlockOperands()) 193 successor.~BlockOperand(); 194 195 // Explicitly destroy the regions. 196 for (auto ®ion : getRegions()) 197 region.~Region(); 198 } 199 200 /// Destroy this operation or one of its subclasses. 201 void Operation::destroy() { 202 this->~Operation(); 203 free(this); 204 } 205 206 /// Return the context this operation is associated with. 207 MLIRContext *Operation::getContext() { return location->getContext(); } 208 209 /// Return the dialect this operation is associated with, or nullptr if the 210 /// associated dialect is not registered. 211 Dialect *Operation::getDialect() { 212 if (auto *abstractOp = getAbstractOperation()) 213 return &abstractOp->dialect; 214 215 // If this operation hasn't been registered or doesn't have abstract 216 // operation, try looking up the dialect name in the context. 217 return getContext()->getLoadedDialect(getName().getDialect()); 218 } 219 220 Region *Operation::getParentRegion() { 221 return block ? block->getParent() : nullptr; 222 } 223 224 Operation *Operation::getParentOp() { 225 return block ? block->getParentOp() : nullptr; 226 } 227 228 /// Return true if this operation is a proper ancestor of the `other` 229 /// operation. 230 bool Operation::isProperAncestor(Operation *other) { 231 while ((other = other->getParentOp())) 232 if (this == other) 233 return true; 234 return false; 235 } 236 237 /// Replace any uses of 'from' with 'to' within this operation. 238 void Operation::replaceUsesOfWith(Value from, Value to) { 239 if (from == to) 240 return; 241 for (auto &operand : getOpOperands()) 242 if (operand.get() == from) 243 operand.set(to); 244 } 245 246 /// Replace the current operands of this operation with the ones provided in 247 /// 'operands'. 248 void Operation::setOperands(ValueRange operands) { 249 if (LLVM_LIKELY(hasOperandStorage)) 250 return getOperandStorage().setOperands(this, operands); 251 assert(operands.empty() && "setting operands without an operand storage"); 252 } 253 254 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 255 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 256 /// than the range pointed to by 'start'+'length'. 257 void Operation::setOperands(unsigned start, unsigned length, 258 ValueRange operands) { 259 assert((start + length) <= getNumOperands() && 260 "invalid operand range specified"); 261 if (LLVM_LIKELY(hasOperandStorage)) 262 return getOperandStorage().setOperands(this, start, length, operands); 263 assert(operands.empty() && "setting operands without an operand storage"); 264 } 265 266 /// Insert the given operands into the operand list at the given 'index'. 267 void Operation::insertOperands(unsigned index, ValueRange operands) { 268 if (LLVM_LIKELY(hasOperandStorage)) 269 return setOperands(index, /*length=*/0, operands); 270 assert(operands.empty() && "inserting operands without an operand storage"); 271 } 272 273 //===----------------------------------------------------------------------===// 274 // Diagnostics 275 //===----------------------------------------------------------------------===// 276 277 /// Emit an error about fatal conditions with this operation, reporting up to 278 /// any diagnostic handlers that may be listening. 279 InFlightDiagnostic Operation::emitError(const Twine &message) { 280 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 281 if (getContext()->shouldPrintOpOnDiagnostic()) { 282 // Print out the operation explicitly here so that we can print the generic 283 // form. 284 // TODO: It would be nice if we could instead provide the 285 // specific printing flags when adding the operation as an argument to the 286 // diagnostic. 287 std::string printedOp; 288 { 289 llvm::raw_string_ostream os(printedOp); 290 print(os, OpPrintingFlags().printGenericOpForm().useLocalScope()); 291 } 292 diag.attachNote(getLoc()) << "see current operation: " << printedOp; 293 } 294 return diag; 295 } 296 297 /// Emit a warning about this operation, reporting up to any diagnostic 298 /// handlers that may be listening. 299 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 300 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 301 if (getContext()->shouldPrintOpOnDiagnostic()) 302 diag.attachNote(getLoc()) << "see current operation: " << *this; 303 return diag; 304 } 305 306 /// Emit a remark about this operation, reporting up to any diagnostic 307 /// handlers that may be listening. 308 InFlightDiagnostic Operation::emitRemark(const Twine &message) { 309 InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); 310 if (getContext()->shouldPrintOpOnDiagnostic()) 311 diag.attachNote(getLoc()) << "see current operation: " << *this; 312 return diag; 313 } 314 315 //===----------------------------------------------------------------------===// 316 // Operation Ordering 317 //===----------------------------------------------------------------------===// 318 319 constexpr unsigned Operation::kInvalidOrderIdx; 320 constexpr unsigned Operation::kOrderStride; 321 322 /// Given an operation 'other' that is within the same parent block, return 323 /// whether the current operation is before 'other' in the operation list 324 /// of the parent block. 325 /// Note: This function has an average complexity of O(1), but worst case may 326 /// take O(N) where N is the number of operations within the parent block. 327 bool Operation::isBeforeInBlock(Operation *other) { 328 assert(block && "Operations without parent blocks have no order."); 329 assert(other && other->block == block && 330 "Expected other operation to have the same parent block."); 331 // If the order of the block is already invalid, directly recompute the 332 // parent. 333 if (!block->isOpOrderValid()) { 334 block->recomputeOpOrder(); 335 } else { 336 // Update the order either operation if necessary. 337 updateOrderIfNecessary(); 338 other->updateOrderIfNecessary(); 339 } 340 341 return orderIndex < other->orderIndex; 342 } 343 344 /// Update the order index of this operation of this operation if necessary, 345 /// potentially recomputing the order of the parent block. 346 void Operation::updateOrderIfNecessary() { 347 assert(block && "expected valid parent"); 348 349 // If the order is valid for this operation there is nothing to do. 350 if (hasValidOrder()) 351 return; 352 Operation *blockFront = &block->front(); 353 Operation *blockBack = &block->back(); 354 355 // This method is expected to only be invoked on blocks with more than one 356 // operation. 357 assert(blockFront != blockBack && "expected more than one operation"); 358 359 // If the operation is at the end of the block. 360 if (this == blockBack) { 361 Operation *prevNode = getPrevNode(); 362 if (!prevNode->hasValidOrder()) 363 return block->recomputeOpOrder(); 364 365 // Add the stride to the previous operation. 366 orderIndex = prevNode->orderIndex + kOrderStride; 367 return; 368 } 369 370 // If this is the first operation try to use the next operation to compute the 371 // ordering. 372 if (this == blockFront) { 373 Operation *nextNode = getNextNode(); 374 if (!nextNode->hasValidOrder()) 375 return block->recomputeOpOrder(); 376 // There is no order to give this operation. 377 if (nextNode->orderIndex == 0) 378 return block->recomputeOpOrder(); 379 380 // If we can't use the stride, just take the middle value left. This is safe 381 // because we know there is at least one valid index to assign to. 382 if (nextNode->orderIndex <= kOrderStride) 383 orderIndex = (nextNode->orderIndex / 2); 384 else 385 orderIndex = kOrderStride; 386 return; 387 } 388 389 // Otherwise, this operation is between two others. Place this operation in 390 // the middle of the previous and next if possible. 391 Operation *prevNode = getPrevNode(), *nextNode = getNextNode(); 392 if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder()) 393 return block->recomputeOpOrder(); 394 unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex; 395 396 // Check to see if there is a valid order between the two. 397 if (prevOrder + 1 == nextOrder) 398 return block->recomputeOpOrder(); 399 orderIndex = prevOrder + ((nextOrder - prevOrder) / 2); 400 } 401 402 //===----------------------------------------------------------------------===// 403 // ilist_traits for Operation 404 //===----------------------------------------------------------------------===// 405 406 auto llvm::ilist_detail::SpecificNodeAccess< 407 typename llvm::ilist_detail::compute_node_options< 408 ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * { 409 return NodeAccess::getNodePtr<OptionsT>(N); 410 } 411 412 auto llvm::ilist_detail::SpecificNodeAccess< 413 typename llvm::ilist_detail::compute_node_options< 414 ::mlir::Operation>::type>::getNodePtr(const_pointer N) 415 -> const node_type * { 416 return NodeAccess::getNodePtr<OptionsT>(N); 417 } 418 419 auto llvm::ilist_detail::SpecificNodeAccess< 420 typename llvm::ilist_detail::compute_node_options< 421 ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer { 422 return NodeAccess::getValuePtr<OptionsT>(N); 423 } 424 425 auto llvm::ilist_detail::SpecificNodeAccess< 426 typename llvm::ilist_detail::compute_node_options< 427 ::mlir::Operation>::type>::getValuePtr(const node_type *N) 428 -> const_pointer { 429 return NodeAccess::getValuePtr<OptionsT>(N); 430 } 431 432 void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { 433 op->destroy(); 434 } 435 436 Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { 437 size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); 438 iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this)); 439 return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset); 440 } 441 442 /// This is a trait method invoked when an operation is added to a block. We 443 /// keep the block pointer up to date. 444 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 445 assert(!op->getBlock() && "already in an operation block!"); 446 op->block = getContainingBlock(); 447 448 // Invalidate the order on the operation. 449 op->orderIndex = Operation::kInvalidOrderIdx; 450 } 451 452 /// This is a trait method invoked when an operation is removed from a block. 453 /// We keep the block pointer up to date. 454 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 455 assert(op->block && "not already in an operation block!"); 456 op->block = nullptr; 457 } 458 459 /// This is a trait method invoked when an operation is moved from one block 460 /// to another. We keep the block pointer up to date. 461 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 462 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 463 Block *curParent = getContainingBlock(); 464 465 // Invalidate the ordering of the parent block. 466 curParent->invalidateOpOrder(); 467 468 // If we are transferring operations within the same block, the block 469 // pointer doesn't need to be updated. 470 if (curParent == otherList.getContainingBlock()) 471 return; 472 473 // Update the 'block' member of each operation. 474 for (; first != last; ++first) 475 first->block = curParent; 476 } 477 478 /// Remove this operation (and its descendants) from its Block and delete 479 /// all of them. 480 void Operation::erase() { 481 if (auto *parent = getBlock()) 482 parent->getOperations().erase(this); 483 else 484 destroy(); 485 } 486 487 /// Unlink this operation from its current block and insert it right before 488 /// `existingOp` which may be in the same or another block in the same 489 /// function. 490 void Operation::moveBefore(Operation *existingOp) { 491 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 492 } 493 494 /// Unlink this operation from its current basic block and insert it right 495 /// before `iterator` in the specified basic block. 496 void Operation::moveBefore(Block *block, 497 llvm::iplist<Operation>::iterator iterator) { 498 block->getOperations().splice(iterator, getBlock()->getOperations(), 499 getIterator()); 500 } 501 502 /// Unlink this operation from its current block and insert it right after 503 /// `existingOp` which may be in the same or another block in the same function. 504 void Operation::moveAfter(Operation *existingOp) { 505 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 506 } 507 508 /// Unlink this operation from its current block and insert it right after 509 /// `iterator` in the specified block. 510 void Operation::moveAfter(Block *block, 511 llvm::iplist<Operation>::iterator iterator) { 512 assert(iterator != block->end() && "cannot move after end of block"); 513 moveBefore(&*std::next(iterator)); 514 } 515 516 /// This drops all operand uses from this operation, which is an essential 517 /// step in breaking cyclic dependences between references when they are to 518 /// be deleted. 519 void Operation::dropAllReferences() { 520 for (auto &op : getOpOperands()) 521 op.drop(); 522 523 for (auto ®ion : getRegions()) 524 region.dropAllReferences(); 525 526 for (auto &dest : getBlockOperands()) 527 dest.drop(); 528 } 529 530 /// This drops all uses of any values defined by this operation or its nested 531 /// regions, wherever they are located. 532 void Operation::dropAllDefinedValueUses() { 533 dropAllUses(); 534 535 for (auto ®ion : getRegions()) 536 for (auto &block : region) 537 block.dropAllDefinedValueUses(); 538 } 539 540 /// Return the number of results held by this operation. 541 unsigned Operation::getNumResults() { 542 if (!resultType) 543 return 0; 544 return hasSingleResult ? 1 : resultType.cast<TupleType>().size(); 545 } 546 547 auto Operation::getResultTypes() -> result_type_range { 548 if (!resultType) 549 return llvm::None; 550 if (hasSingleResult) 551 return resultType; 552 return resultType.cast<TupleType>().getTypes(); 553 } 554 555 void Operation::setSuccessor(Block *block, unsigned index) { 556 assert(index < getNumSuccessors()); 557 getBlockOperands()[index].set(block); 558 } 559 560 /// Attempt to fold this operation using the Op's registered foldHook. 561 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 562 SmallVectorImpl<OpFoldResult> &results) { 563 // If we have a registered operation definition matching this one, use it to 564 // try to constant fold the operation. 565 auto *abstractOp = getAbstractOperation(); 566 if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results))) 567 return success(); 568 569 // Otherwise, fall back on the dialect hook to handle it. 570 Dialect *dialect = getDialect(); 571 if (!dialect) 572 return failure(); 573 574 auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>(); 575 if (!interface) 576 return failure(); 577 578 return interface->fold(this, operands, results); 579 } 580 581 /// Emit an error with the op name prefixed, like "'dim' op " which is 582 /// convenient for verifiers. 583 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 584 return emitError() << "'" << getName() << "' op " << message; 585 } 586 587 //===----------------------------------------------------------------------===// 588 // Operation Cloning 589 //===----------------------------------------------------------------------===// 590 591 /// Create a deep copy of this operation but keep the operation regions empty. 592 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 593 /// to contain the results. 594 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 595 SmallVector<Value, 8> operands; 596 SmallVector<Block *, 2> successors; 597 598 // Remap the operands. 599 operands.reserve(getNumOperands()); 600 for (auto opValue : getOperands()) 601 operands.push_back(mapper.lookupOrDefault(opValue)); 602 603 // Remap the successors. 604 successors.reserve(getNumSuccessors()); 605 for (Block *successor : getSuccessors()) 606 successors.push_back(mapper.lookupOrDefault(successor)); 607 608 // Create the new operation. 609 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 610 successors, getNumRegions()); 611 612 // Remember the mapping of any results. 613 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 614 mapper.map(getResult(i), newOp->getResult(i)); 615 616 return newOp; 617 } 618 619 Operation *Operation::cloneWithoutRegions() { 620 BlockAndValueMapping mapper; 621 return cloneWithoutRegions(mapper); 622 } 623 624 /// Create a deep copy of this operation, remapping any operands that use 625 /// values outside of the operation using the map that is provided (leaving 626 /// them alone if no entry is present). Replaces references to cloned 627 /// sub-operations to the corresponding operation that is copied, and adds 628 /// those mappings to the map. 629 Operation *Operation::clone(BlockAndValueMapping &mapper) { 630 auto *newOp = cloneWithoutRegions(mapper); 631 632 // Clone the regions. 633 for (unsigned i = 0; i != numRegions; ++i) 634 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 635 636 return newOp; 637 } 638 639 Operation *Operation::clone() { 640 BlockAndValueMapping mapper; 641 return clone(mapper); 642 } 643 644 //===----------------------------------------------------------------------===// 645 // OpState trait class. 646 //===----------------------------------------------------------------------===// 647 648 // The fallback for the parser is to reject the custom assembly form. 649 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 650 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 651 } 652 653 // The fallback for the printer is to print in the generic assembly form. 654 void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); } 655 656 /// Emit an error about fatal conditions with this operation, reporting up to 657 /// any diagnostic handlers that may be listening. 658 InFlightDiagnostic OpState::emitError(const Twine &message) { 659 return getOperation()->emitError(message); 660 } 661 662 /// Emit an error with the op name prefixed, like "'dim' op " which is 663 /// convenient for verifiers. 664 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 665 return getOperation()->emitOpError(message); 666 } 667 668 /// Emit a warning about this operation, reporting up to any diagnostic 669 /// handlers that may be listening. 670 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 671 return getOperation()->emitWarning(message); 672 } 673 674 /// Emit a remark about this operation, reporting up to any diagnostic 675 /// handlers that may be listening. 676 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 677 return getOperation()->emitRemark(message); 678 } 679 680 //===----------------------------------------------------------------------===// 681 // Op Trait implementations 682 //===----------------------------------------------------------------------===// 683 684 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 685 auto *argumentOp = op->getOperand(0).getDefiningOp(); 686 if (argumentOp && op->getName() == argumentOp->getName()) { 687 // Replace the outer operation output with the inner operation. 688 return op->getOperand(0); 689 } 690 691 return {}; 692 } 693 694 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 695 auto *argumentOp = op->getOperand(0).getDefiningOp(); 696 if (argumentOp && op->getName() == argumentOp->getName()) { 697 // Replace the outer involutions output with inner's input. 698 return argumentOp->getOperand(0); 699 } 700 701 return {}; 702 } 703 704 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 705 if (op->getNumOperands() != 0) 706 return op->emitOpError() << "requires zero operands"; 707 return success(); 708 } 709 710 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 711 if (op->getNumOperands() != 1) 712 return op->emitOpError() << "requires a single operand"; 713 return success(); 714 } 715 716 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 717 unsigned numOperands) { 718 if (op->getNumOperands() != numOperands) { 719 return op->emitOpError() << "expected " << numOperands 720 << " operands, but found " << op->getNumOperands(); 721 } 722 return success(); 723 } 724 725 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 726 unsigned numOperands) { 727 if (op->getNumOperands() < numOperands) 728 return op->emitOpError() 729 << "expected " << numOperands << " or more operands"; 730 return success(); 731 } 732 733 /// If this is a vector type, or a tensor type, return the scalar element type 734 /// that it is built around, otherwise return the type unmodified. 735 static Type getTensorOrVectorElementType(Type type) { 736 if (auto vec = type.dyn_cast<VectorType>()) 737 return vec.getElementType(); 738 739 // Look through tensor<vector<...>> to find the underlying element type. 740 if (auto tensor = type.dyn_cast<TensorType>()) 741 return getTensorOrVectorElementType(tensor.getElementType()); 742 return type; 743 } 744 745 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 746 // FIXME: Add back check for no side effects on operation. 747 // Currently adding it would cause the shared library build 748 // to fail since there would be a dependency of IR on SideEffectInterfaces 749 // which is cyclical. 750 return success(); 751 } 752 753 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 754 // FIXME: Add back check for no side effects on operation. 755 // Currently adding it would cause the shared library build 756 // to fail since there would be a dependency of IR on SideEffectInterfaces 757 // which is cyclical. 758 return success(); 759 } 760 761 LogicalResult 762 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 763 for (auto opType : op->getOperandTypes()) { 764 auto type = getTensorOrVectorElementType(opType); 765 if (!type.isSignlessIntOrIndex()) 766 return op->emitOpError() << "requires an integer or index type"; 767 } 768 return success(); 769 } 770 771 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 772 for (auto opType : op->getOperandTypes()) { 773 auto type = getTensorOrVectorElementType(opType); 774 if (!type.isa<FloatType>()) 775 return op->emitOpError("requires a float type"); 776 } 777 return success(); 778 } 779 780 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 781 // Zero or one operand always have the "same" type. 782 unsigned nOperands = op->getNumOperands(); 783 if (nOperands < 2) 784 return success(); 785 786 auto type = op->getOperand(0).getType(); 787 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 788 if (opType != type) 789 return op->emitOpError() << "requires all operands to have the same type"; 790 return success(); 791 } 792 793 LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) { 794 if (op->getNumRegions() != 0) 795 return op->emitOpError() << "requires zero regions"; 796 return success(); 797 } 798 799 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 800 if (op->getNumRegions() != 1) 801 return op->emitOpError() << "requires one region"; 802 return success(); 803 } 804 805 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 806 unsigned numRegions) { 807 if (op->getNumRegions() != numRegions) 808 return op->emitOpError() << "expected " << numRegions << " regions"; 809 return success(); 810 } 811 812 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 813 unsigned numRegions) { 814 if (op->getNumRegions() < numRegions) 815 return op->emitOpError() << "expected " << numRegions << " or more regions"; 816 return success(); 817 } 818 819 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { 820 if (op->getNumResults() != 0) 821 return op->emitOpError() << "requires zero results"; 822 return success(); 823 } 824 825 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 826 if (op->getNumResults() != 1) 827 return op->emitOpError() << "requires one result"; 828 return success(); 829 } 830 831 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 832 unsigned numOperands) { 833 if (op->getNumResults() != numOperands) 834 return op->emitOpError() << "expected " << numOperands << " results"; 835 return success(); 836 } 837 838 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 839 unsigned numOperands) { 840 if (op->getNumResults() < numOperands) 841 return op->emitOpError() 842 << "expected " << numOperands << " or more results"; 843 return success(); 844 } 845 846 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 847 if (failed(verifyAtLeastNOperands(op, 1))) 848 return failure(); 849 850 auto type = op->getOperand(0).getType(); 851 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) { 852 if (failed(verifyCompatibleShape(opType, type))) 853 return op->emitOpError() << "requires the same shape for all operands"; 854 } 855 return success(); 856 } 857 858 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 859 if (failed(verifyAtLeastNOperands(op, 1)) || 860 failed(verifyAtLeastNResults(op, 1))) 861 return failure(); 862 863 auto type = op->getOperand(0).getType(); 864 for (auto resultType : op->getResultTypes()) { 865 if (failed(verifyCompatibleShape(resultType, type))) 866 return op->emitOpError() 867 << "requires the same shape for all operands and results"; 868 } 869 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) { 870 if (failed(verifyCompatibleShape(opType, type))) 871 return op->emitOpError() 872 << "requires the same shape for all operands and results"; 873 } 874 return success(); 875 } 876 877 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 878 if (failed(verifyAtLeastNOperands(op, 1))) 879 return failure(); 880 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 881 882 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 883 if (getElementTypeOrSelf(operand) != elementType) 884 return op->emitOpError("requires the same element type for all operands"); 885 } 886 887 return success(); 888 } 889 890 LogicalResult 891 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 892 if (failed(verifyAtLeastNOperands(op, 1)) || 893 failed(verifyAtLeastNResults(op, 1))) 894 return failure(); 895 896 auto elementType = getElementTypeOrSelf(op->getResult(0)); 897 898 // Verify result element type matches first result's element type. 899 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 900 if (getElementTypeOrSelf(result) != elementType) 901 return op->emitOpError( 902 "requires the same element type for all operands and results"); 903 } 904 905 // Verify operand's element type matches first result's element type. 906 for (auto operand : op->getOperands()) { 907 if (getElementTypeOrSelf(operand) != elementType) 908 return op->emitOpError( 909 "requires the same element type for all operands and results"); 910 } 911 912 return success(); 913 } 914 915 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 916 if (failed(verifyAtLeastNOperands(op, 1)) || 917 failed(verifyAtLeastNResults(op, 1))) 918 return failure(); 919 920 auto type = op->getResult(0).getType(); 921 auto elementType = getElementTypeOrSelf(type); 922 for (auto resultType : op->getResultTypes().drop_front(1)) { 923 if (getElementTypeOrSelf(resultType) != elementType || 924 failed(verifyCompatibleShape(resultType, type))) 925 return op->emitOpError() 926 << "requires the same type for all operands and results"; 927 } 928 for (auto opType : op->getOperandTypes()) { 929 if (getElementTypeOrSelf(opType) != elementType || 930 failed(verifyCompatibleShape(opType, type))) 931 return op->emitOpError() 932 << "requires the same type for all operands and results"; 933 } 934 return success(); 935 } 936 937 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 938 Block *block = op->getBlock(); 939 // Verify that the operation is at the end of the respective parent block. 940 if (!block || &block->back() != op) 941 return op->emitOpError("must be the last operation in the parent block"); 942 return success(); 943 } 944 945 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 946 auto *parent = op->getParentRegion(); 947 948 // Verify that the operands lines up with the BB arguments in the successor. 949 for (Block *succ : op->getSuccessors()) 950 if (succ->getParent() != parent) 951 return op->emitError("reference to block defined in another region"); 952 return success(); 953 } 954 955 LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { 956 if (op->getNumSuccessors() != 0) { 957 return op->emitOpError("requires 0 successors but found ") 958 << op->getNumSuccessors(); 959 } 960 return success(); 961 } 962 963 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 964 if (op->getNumSuccessors() != 1) { 965 return op->emitOpError("requires 1 successor but found ") 966 << op->getNumSuccessors(); 967 } 968 return verifyTerminatorSuccessors(op); 969 } 970 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 971 unsigned numSuccessors) { 972 if (op->getNumSuccessors() != numSuccessors) { 973 return op->emitOpError("requires ") 974 << numSuccessors << " successors but found " 975 << op->getNumSuccessors(); 976 } 977 return verifyTerminatorSuccessors(op); 978 } 979 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 980 unsigned numSuccessors) { 981 if (op->getNumSuccessors() < numSuccessors) { 982 return op->emitOpError("requires at least ") 983 << numSuccessors << " successors but found " 984 << op->getNumSuccessors(); 985 } 986 return verifyTerminatorSuccessors(op); 987 } 988 989 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 990 for (auto resultType : op->getResultTypes()) { 991 auto elementType = getTensorOrVectorElementType(resultType); 992 bool isBoolType = elementType.isInteger(1); 993 if (!isBoolType) 994 return op->emitOpError() << "requires a bool result type"; 995 } 996 997 return success(); 998 } 999 1000 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 1001 for (auto resultType : op->getResultTypes()) 1002 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 1003 return op->emitOpError() << "requires a floating point type"; 1004 1005 return success(); 1006 } 1007 1008 LogicalResult 1009 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 1010 for (auto resultType : op->getResultTypes()) 1011 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 1012 return op->emitOpError() << "requires an integer or index type"; 1013 return success(); 1014 } 1015 1016 static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName, 1017 bool isOperand) { 1018 auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); 1019 if (!sizeAttr) 1020 return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; 1021 1022 auto sizeAttrType = sizeAttr.getType().dyn_cast<VectorType>(); 1023 if (!sizeAttrType || sizeAttrType.getRank() != 1) 1024 return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; 1025 1026 if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) { 1027 return !element.isNonNegative(); 1028 })) 1029 return op->emitOpError("'") 1030 << attrName << "' attribute cannot have negative elements"; 1031 1032 size_t totalCount = std::accumulate( 1033 sizeAttr.begin(), sizeAttr.end(), 0, 1034 [](unsigned all, APInt one) { return all + one.getZExtValue(); }); 1035 1036 if (isOperand && totalCount != op->getNumOperands()) 1037 return op->emitOpError("operand count (") 1038 << op->getNumOperands() << ") does not match with the total size (" 1039 << totalCount << ") specified in attribute '" << attrName << "'"; 1040 else if (!isOperand && totalCount != op->getNumResults()) 1041 return op->emitOpError("result count (") 1042 << op->getNumResults() << ") does not match with the total size (" 1043 << totalCount << ") specified in attribute '" << attrName << "'"; 1044 return success(); 1045 } 1046 1047 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1048 StringRef attrName) { 1049 return verifyValueSizeAttr(op, attrName, /*isOperand=*/true); 1050 } 1051 1052 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1053 StringRef attrName) { 1054 return verifyValueSizeAttr(op, attrName, /*isOperand=*/false); 1055 } 1056 1057 LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1058 for (Region ®ion : op->getRegions()) { 1059 if (region.empty()) 1060 continue; 1061 1062 if (region.getNumArguments() != 0) { 1063 if (op->getNumRegions() > 1) 1064 return op->emitOpError("region #") 1065 << region.getRegionNumber() << " should have no arguments"; 1066 else 1067 return op->emitOpError("region should have no arguments"); 1068 } 1069 } 1070 return success(); 1071 } 1072 1073 /// Checks if two ShapedTypes are the same, ignoring the element type. 1074 static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) { 1075 if (a.getTypeID() != b.getTypeID()) 1076 return false; 1077 if (!a.hasRank()) 1078 return !b.hasRank(); 1079 return a.getShape() == b.getShape(); 1080 } 1081 1082 LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) { 1083 auto isMappableType = [](Type type) { 1084 return type.isa<VectorType, TensorType>(); 1085 }; 1086 auto resultMappableTypes = llvm::to_vector<1>( 1087 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1088 auto operandMappableTypes = llvm::to_vector<2>( 1089 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1090 1091 // If the op only has scalar operand/result types, then we have nothing to 1092 // check. 1093 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1094 return success(); 1095 1096 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1097 return op->emitOpError("if a result is non-scalar, then at least one " 1098 "operand must be non-scalar"); 1099 1100 assert(!operandMappableTypes.empty()); 1101 1102 if (resultMappableTypes.empty()) 1103 return op->emitOpError("if an operand is non-scalar, then there must be at " 1104 "least one non-scalar result"); 1105 1106 if (resultMappableTypes.size() != op->getNumResults()) 1107 return op->emitOpError( 1108 "if an operand is non-scalar, then all results must be non-scalar"); 1109 1110 auto mustMatchType = operandMappableTypes[0].cast<ShapedType>(); 1111 for (auto type : 1112 llvm::concat<Type>(resultMappableTypes, operandMappableTypes)) { 1113 if (!areSameShapedTypeIgnoringElementType(type.cast<ShapedType>(), 1114 mustMatchType)) { 1115 return op->emitOpError() << "all non-scalar operands/results must have " 1116 "the same shape and base type: found " 1117 << type << " and " << mustMatchType; 1118 } 1119 } 1120 1121 return success(); 1122 } 1123 1124 //===----------------------------------------------------------------------===// 1125 // BinaryOp implementation 1126 //===----------------------------------------------------------------------===// 1127 1128 // These functions are out-of-line implementations of the methods in BinaryOp, 1129 // which avoids them being template instantiated/duplicated. 1130 1131 void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, 1132 Value rhs) { 1133 assert(lhs.getType() == rhs.getType()); 1134 result.addOperands({lhs, rhs}); 1135 result.types.push_back(lhs.getType()); 1136 } 1137 1138 ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser, 1139 OperationState &result) { 1140 SmallVector<OpAsmParser::OperandType, 2> ops; 1141 Type type; 1142 return failure(parser.parseOperandList(ops) || 1143 parser.parseOptionalAttrDict(result.attributes) || 1144 parser.parseColonType(type) || 1145 parser.resolveOperands(ops, type, result.operands) || 1146 parser.addTypeToList(type, result.types)); 1147 } 1148 1149 void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { 1150 assert(op->getNumResults() == 1 && "op should have one result"); 1151 1152 // If not all the operand and result types are the same, just use the 1153 // generic assembly form to avoid omitting information in printing. 1154 auto resultType = op->getResult(0).getType(); 1155 if (llvm::any_of(op->getOperandTypes(), 1156 [&](Type type) { return type != resultType; })) { 1157 p.printGenericOp(op); 1158 return; 1159 } 1160 1161 p << op->getName() << ' '; 1162 p.printOperands(op->getOperands()); 1163 p.printOptionalAttrDict(op->getAttrs()); 1164 // Now we can output only one type for all operands and the result. 1165 p << " : " << resultType; 1166 } 1167 1168 //===----------------------------------------------------------------------===// 1169 // CastOp implementation 1170 //===----------------------------------------------------------------------===// 1171 1172 void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source, 1173 Type destType) { 1174 result.addOperands(source); 1175 result.addTypes(destType); 1176 } 1177 1178 ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) { 1179 OpAsmParser::OperandType srcInfo; 1180 Type srcType, dstType; 1181 return failure(parser.parseOperand(srcInfo) || 1182 parser.parseOptionalAttrDict(result.attributes) || 1183 parser.parseColonType(srcType) || 1184 parser.resolveOperand(srcInfo, srcType, result.operands) || 1185 parser.parseKeywordType("to", dstType) || 1186 parser.addTypeToList(dstType, result.types)); 1187 } 1188 1189 void impl::printCastOp(Operation *op, OpAsmPrinter &p) { 1190 p << op->getName() << ' ' << op->getOperand(0); 1191 p.printOptionalAttrDict(op->getAttrs()); 1192 p << " : " << op->getOperand(0).getType() << " to " 1193 << op->getResult(0).getType(); 1194 } 1195 1196 Value impl::foldCastOp(Operation *op) { 1197 // Identity cast 1198 if (op->getOperand(0).getType() == op->getResult(0).getType()) 1199 return op->getOperand(0); 1200 return nullptr; 1201 } 1202 1203 //===----------------------------------------------------------------------===// 1204 // Misc. utils 1205 //===----------------------------------------------------------------------===// 1206 1207 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1208 /// region's only block if it does not have a terminator already. If the region 1209 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1210 /// terminator operation to insert. 1211 void impl::ensureRegionTerminator( 1212 Region ®ion, OpBuilder &builder, Location loc, 1213 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1214 OpBuilder::InsertionGuard guard(builder); 1215 if (region.empty()) 1216 builder.createBlock(®ion); 1217 1218 Block &block = region.back(); 1219 if (!block.empty() && block.back().isKnownTerminator()) 1220 return; 1221 1222 builder.setInsertionPointToEnd(&block); 1223 builder.insert(buildTerminatorOp(builder, loc)); 1224 } 1225 1226 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1227 /// function. 1228 void impl::ensureRegionTerminator( 1229 Region ®ion, Builder &builder, Location loc, 1230 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1231 OpBuilder opBuilder(builder.getContext()); 1232 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1233 } 1234 1235 //===----------------------------------------------------------------------===// 1236 // UseIterator 1237 //===----------------------------------------------------------------------===// 1238 1239 Operation::UseIterator::UseIterator(Operation *op, bool end) 1240 : op(op), res(end ? op->result_end() : op->result_begin()) { 1241 // Only initialize current use if there are results/can be uses. 1242 if (op->getNumResults()) 1243 skipOverResultsWithNoUsers(); 1244 } 1245 1246 Operation::UseIterator &Operation::UseIterator::operator++() { 1247 // We increment over uses, if we reach the last use then move to next 1248 // result. 1249 if (use != (*res).use_end()) 1250 ++use; 1251 if (use == (*res).use_end()) { 1252 ++res; 1253 skipOverResultsWithNoUsers(); 1254 } 1255 return *this; 1256 } 1257 1258 void Operation::UseIterator::skipOverResultsWithNoUsers() { 1259 while (res != op->result_end() && (*res).use_empty()) 1260 ++res; 1261 1262 // If we are at the last result, then set use to first use of 1263 // first result (sentinel value used for end). 1264 if (res == op->result_end()) 1265 use = {}; 1266 else 1267 use = (*res).use_begin(); 1268 } 1269