1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Interfaces/FoldInterfaces.h" 17 #include "llvm/ADT/StringExtras.h" 18 #include <numeric> 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // Operation 24 //===----------------------------------------------------------------------===// 25 26 /// Create a new Operation from operation state. 27 Operation *Operation::create(const OperationState &state) { 28 return create(state.location, state.name, state.types, state.operands, 29 state.attributes.getDictionary(state.getContext()), 30 state.successors, state.regions); 31 } 32 33 /// Create a new Operation with the specific fields. 34 Operation *Operation::create(Location location, OperationName name, 35 TypeRange resultTypes, ValueRange operands, 36 NamedAttrList &&attributes, BlockRange successors, 37 RegionRange regions) { 38 unsigned numRegions = regions.size(); 39 Operation *op = create(location, name, resultTypes, operands, 40 std::move(attributes), successors, numRegions); 41 for (unsigned i = 0; i < numRegions; ++i) 42 if (regions[i]) 43 op->getRegion(i).takeBody(*regions[i]); 44 return op; 45 } 46 47 /// Overload of create that takes an existing DictionaryAttr to avoid 48 /// unnecessarily uniquing a list of attributes. 
49 Operation *Operation::create(Location location, OperationName name, 50 TypeRange resultTypes, ValueRange operands, 51 NamedAttrList &&attributes, BlockRange successors, 52 unsigned numRegions) { 53 assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && 54 "unexpected null result type"); 55 56 // We only need to allocate additional memory for a subset of results. 57 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 58 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 59 unsigned numSuccessors = successors.size(); 60 unsigned numOperands = operands.size(); 61 unsigned numResults = resultTypes.size(); 62 63 // If the operation is known to have no operands, don't allocate an operand 64 // storage. 65 bool needsOperandStorage = 66 operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; 67 68 // Compute the byte size for the operation and the operand storage. This takes 69 // into account the size of the operation, its trailing objects, and its 70 // prefixed objects. 71 size_t byteSize = 72 totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>( 73 needsOperandStorage ? 1 : 0, numSuccessors, numRegions, numOperands); 74 size_t prefixByteSize = llvm::alignTo( 75 Operation::prefixAllocSize(numTrailingResults, numInlineResults), 76 alignof(Operation)); 77 char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); 78 void *rawMem = mallocMem + prefixByteSize; 79 80 // Create the new Operation. 81 Operation *op = ::new (rawMem) Operation( 82 location, name, numResults, numSuccessors, numRegions, 83 attributes.getDictionary(location.getContext()), needsOperandStorage); 84 85 assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && 86 "unexpected successors in a non-terminator operation"); 87 88 // Initialize the results. 
89 auto resultTypeIt = resultTypes.begin(); 90 for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) 91 new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); 92 for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { 93 new (op->getOutOfLineOpResult(i)) 94 detail::OutOfLineOpResult(*resultTypeIt, i); 95 } 96 97 // Initialize the regions. 98 for (unsigned i = 0; i != numRegions; ++i) 99 new (&op->getRegion(i)) Region(op); 100 101 // Initialize the operands. 102 if (needsOperandStorage) { 103 new (&op->getOperandStorage()) detail::OperandStorage( 104 op, op->getTrailingObjects<OpOperand>(), operands); 105 } 106 107 // Initialize the successors. 108 auto blockOperands = op->getBlockOperands(); 109 for (unsigned i = 0; i != numSuccessors; ++i) 110 new (&blockOperands[i]) BlockOperand(op, successors[i]); 111 112 return op; 113 } 114 115 Operation::Operation(Location location, OperationName name, unsigned numResults, 116 unsigned numSuccessors, unsigned numRegions, 117 DictionaryAttr attributes, bool hasOperandStorage) 118 : location(location), numResults(numResults), numSuccs(numSuccessors), 119 numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), 120 attrs(attributes) { 121 assert(attributes && "unexpected null attribute dictionary"); 122 #ifndef NDEBUG 123 if (!getDialect() && !getContext()->allowsUnregisteredDialects()) 124 llvm::report_fatal_error( 125 name.getStringRef() + 126 " created with unregistered dialect. If this is intended, please call " 127 "allowUnregisteredDialects() on the MLIRContext, or use " 128 "-allow-unregistered-dialect with the MLIR tool used."); 129 #endif 130 } 131 132 // Operations are deleted through the destroy() member because they are 133 // allocated via malloc. 
134 Operation::~Operation() { 135 assert(block == nullptr && "operation destroyed but still in a block"); 136 #ifndef NDEBUG 137 if (!use_empty()) { 138 { 139 InFlightDiagnostic diag = 140 emitOpError("operation destroyed but still has uses"); 141 for (Operation *user : getUsers()) 142 diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; 143 } 144 llvm::report_fatal_error("operation destroyed but still has uses"); 145 } 146 #endif 147 // Explicitly run the destructors for the operands. 148 if (hasOperandStorage) 149 getOperandStorage().~OperandStorage(); 150 151 // Explicitly run the destructors for the successors. 152 for (auto &successor : getBlockOperands()) 153 successor.~BlockOperand(); 154 155 // Explicitly destroy the regions. 156 for (auto ®ion : getRegions()) 157 region.~Region(); 158 } 159 160 /// Destroy this operation or one of its subclasses. 161 void Operation::destroy() { 162 // Operations may have additional prefixed allocation, which needs to be 163 // accounted for here when computing the address to free. 164 char *rawMem = reinterpret_cast<char *>(this) - 165 llvm::alignTo(prefixAllocSize(), alignof(Operation)); 166 this->~Operation(); 167 free(rawMem); 168 } 169 170 /// Return true if this operation is a proper ancestor of the `other` 171 /// operation. 172 bool Operation::isProperAncestor(Operation *other) { 173 while ((other = other->getParentOp())) 174 if (this == other) 175 return true; 176 return false; 177 } 178 179 /// Replace any uses of 'from' with 'to' within this operation. 180 void Operation::replaceUsesOfWith(Value from, Value to) { 181 if (from == to) 182 return; 183 for (auto &operand : getOpOperands()) 184 if (operand.get() == from) 185 operand.set(to); 186 } 187 188 /// Replace the current operands of this operation with the ones provided in 189 /// 'operands'. 
190 void Operation::setOperands(ValueRange operands) { 191 if (LLVM_LIKELY(hasOperandStorage)) 192 return getOperandStorage().setOperands(this, operands); 193 assert(operands.empty() && "setting operands without an operand storage"); 194 } 195 196 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 197 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 198 /// than the range pointed to by 'start'+'length'. 199 void Operation::setOperands(unsigned start, unsigned length, 200 ValueRange operands) { 201 assert((start + length) <= getNumOperands() && 202 "invalid operand range specified"); 203 if (LLVM_LIKELY(hasOperandStorage)) 204 return getOperandStorage().setOperands(this, start, length, operands); 205 assert(operands.empty() && "setting operands without an operand storage"); 206 } 207 208 /// Insert the given operands into the operand list at the given 'index'. 209 void Operation::insertOperands(unsigned index, ValueRange operands) { 210 if (LLVM_LIKELY(hasOperandStorage)) 211 return setOperands(index, /*length=*/0, operands); 212 assert(operands.empty() && "inserting operands without an operand storage"); 213 } 214 215 //===----------------------------------------------------------------------===// 216 // Diagnostics 217 //===----------------------------------------------------------------------===// 218 219 /// Emit an error about fatal conditions with this operation, reporting up to 220 /// any diagnostic handlers that may be listening. 221 InFlightDiagnostic Operation::emitError(const Twine &message) { 222 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 223 if (getContext()->shouldPrintOpOnDiagnostic()) { 224 diag.attachNote(getLoc()) 225 .append("see current operation: ") 226 .appendOp(*this, OpPrintingFlags().printGenericOpForm()); 227 } 228 return diag; 229 } 230 231 /// Emit a warning about this operation, reporting up to any diagnostic 232 /// handlers that may be listening. 
InFlightDiagnostic Operation::emitWarning(const Twine &message) {
  InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message);
  // Unlike emitError, the attached note prints the op with default flags.
  if (getContext()->shouldPrintOpOnDiagnostic())
    diag.attachNote(getLoc()) << "see current operation: " << *this;
  return diag;
}

/// Emits a remark at this operation's location. If the context asks for ops
/// to be printed on diagnostics, a note showing the operation is attached.
InFlightDiagnostic Operation::emitRemark(const Twine &message) {
  InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message);
  if (getContext()->shouldPrintOpOnDiagnostic())
    diag.attachNote(getLoc()) << "see current operation: " << *this;
  return diag;
}

//===----------------------------------------------------------------------===//
// Operation Ordering
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the class's static constexpr ordering constants.
constexpr unsigned Operation::kInvalidOrderIdx;
constexpr unsigned Operation::kOrderStride;

/// Given an operation 'other' that is within the same parent block, return
/// whether the current operation is before 'other' in the operation list
/// of the parent block.
/// Note: This function has an average complexity of O(1), but worst case may
/// take O(N) where N is the number of operations within the parent block.
bool Operation::isBeforeInBlock(Operation *other) {
  assert(block && "Operations without parent blocks have no order.");
  assert(other && other->block == block &&
         "Expected other operation to have the same parent block.");
  // If the order of the block is already invalid, directly recompute the
  // order of the whole parent block.
  if (!block->isOpOrderValid()) {
    block->recomputeOpOrder();
  } else {
    // Update the order of either operation if necessary.
    updateOrderIfNecessary();
    other->updateOrderIfNecessary();
  }

  return orderIndex < other->orderIndex;
}

/// Update the order index of this operation if necessary, potentially
/// recomputing the order of the parent block. A new index is derived from
/// the neighbours' indices when possible; otherwise the whole block is
/// renumbered via recomputeOpOrder().
void Operation::updateOrderIfNecessary() {
  assert(block && "expected valid parent");

  // If the order is valid for this operation there is nothing to do.
  if (hasValidOrder())
    return;
  Operation *blockFront = &block->front();
  Operation *blockBack = &block->back();

  // This method is expected to only be invoked on blocks with more than one
  // operation.
  assert(blockFront != blockBack && "expected more than one operation");

  // If the operation is at the end of the block.
  if (this == blockBack) {
    Operation *prevNode = getPrevNode();
    if (!prevNode->hasValidOrder())
      return block->recomputeOpOrder();

    // Add the stride to the previous operation.
    // NOTE(review): unsigned arithmetic; assumes indices stay far below the
    // wrap-around point.
    orderIndex = prevNode->orderIndex + kOrderStride;
    return;
  }

  // If this is the first operation try to use the next operation to compute the
  // ordering.
  if (this == blockFront) {
    Operation *nextNode = getNextNode();
    if (!nextNode->hasValidOrder())
      return block->recomputeOpOrder();
    // There is no order to give this operation.
    if (nextNode->orderIndex == 0)
      return block->recomputeOpOrder();

    // If we can't use the stride, just take the middle value left. This is safe
    // because we know there is at least one valid index to assign to.
    if (nextNode->orderIndex <= kOrderStride)
      orderIndex = (nextNode->orderIndex / 2);
    else
      orderIndex = kOrderStride;
    return;
  }

  // Otherwise, this operation is between two others. Place this operation in
  // the middle of the previous and next if possible.
  Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
  if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
    return block->recomputeOpOrder();
  unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;

  // Check to see if there is a valid order between the two.
  if (prevOrder + 1 == nextOrder)
    return block->recomputeOpOrder();
  orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}

//===----------------------------------------------------------------------===//
// ilist_traits for Operation
//===----------------------------------------------------------------------===//

// Out-of-line node <-> value conversions for the Operation ilist; each simply
// forwards to NodeAccess with the computed node options.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer n)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *n)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

/// Deleting a node from the list destroys (and frees) the operation.
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
  op->destroy();
}

/// Recovers the Block that owns this operation list. The offset of the list
/// member inside Block is computed by applying Block's sublist accessor to a
/// null Block pointer, and is then subtracted from this list's address.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset);
}

/// This is a trait method invoked when an operation is added to a block. We
/// keep the block pointer up to date.
void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
  assert(!op->getBlock() && "already in an operation block!");
  op->block = getContainingBlock();

  // Invalidate the order on the operation.
  op->orderIndex = Operation::kInvalidOrderIdx;
}

/// This is a trait method invoked when an operation is removed from a block.
/// We keep the block pointer up to date.
void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) {
  assert(op->block && "not already in an operation block!");
  op->block = nullptr;
}

/// This is a trait method invoked when an operation is moved from one block
/// to another. We keep the block pointer up to date.
void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList(
    ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) {
  Block *curParent = getContainingBlock();

  // Invalidate the ordering of the parent block.
  curParent->invalidateOpOrder();

  // If we are transferring operations within the same block, the block
  // pointer doesn't need to be updated.
  if (curParent == otherList.getContainingBlock())
    return;

  // Update the 'block' member of each operation.
  for (; first != last; ++first)
    first->block = curParent;
}

/// Remove this operation (and its descendants) from its Block and delete
/// all of them.
void Operation::erase() {
  if (auto *parent = getBlock())
    parent->getOperations().erase(this);
  else
    destroy();
}

/// Remove the operation from its parent block, but don't delete it.
422 void Operation::remove() { 423 if (Block *parent = getBlock()) 424 parent->getOperations().remove(this); 425 } 426 427 /// Unlink this operation from its current block and insert it right before 428 /// `existingOp` which may be in the same or another block in the same 429 /// function. 430 void Operation::moveBefore(Operation *existingOp) { 431 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 432 } 433 434 /// Unlink this operation from its current basic block and insert it right 435 /// before `iterator` in the specified basic block. 436 void Operation::moveBefore(Block *block, 437 llvm::iplist<Operation>::iterator iterator) { 438 block->getOperations().splice(iterator, getBlock()->getOperations(), 439 getIterator()); 440 } 441 442 /// Unlink this operation from its current block and insert it right after 443 /// `existingOp` which may be in the same or another block in the same function. 444 void Operation::moveAfter(Operation *existingOp) { 445 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 446 } 447 448 /// Unlink this operation from its current block and insert it right after 449 /// `iterator` in the specified block. 450 void Operation::moveAfter(Block *block, 451 llvm::iplist<Operation>::iterator iterator) { 452 assert(iterator != block->end() && "cannot move after end of block"); 453 moveBefore(block, std::next(iterator)); 454 } 455 456 /// This drops all operand uses from this operation, which is an essential 457 /// step in breaking cyclic dependences between references when they are to 458 /// be deleted. 459 void Operation::dropAllReferences() { 460 for (auto &op : getOpOperands()) 461 op.drop(); 462 463 for (auto ®ion : getRegions()) 464 region.dropAllReferences(); 465 466 for (auto &dest : getBlockOperands()) 467 dest.drop(); 468 } 469 470 /// This drops all uses of any values defined by this operation or its nested 471 /// regions, wherever they are located. 
472 void Operation::dropAllDefinedValueUses() { 473 dropAllUses(); 474 475 for (auto ®ion : getRegions()) 476 for (auto &block : region) 477 block.dropAllDefinedValueUses(); 478 } 479 480 void Operation::setSuccessor(Block *block, unsigned index) { 481 assert(index < getNumSuccessors()); 482 getBlockOperands()[index].set(block); 483 } 484 485 /// Attempt to fold this operation using the Op's registered foldHook. 486 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 487 SmallVectorImpl<OpFoldResult> &results) { 488 // If we have a registered operation definition matching this one, use it to 489 // try to constant fold the operation. 490 Optional<RegisteredOperationName> info = getRegisteredInfo(); 491 if (info && succeeded(info->foldHook(this, operands, results))) 492 return success(); 493 494 // Otherwise, fall back on the dialect hook to handle it. 495 Dialect *dialect = getDialect(); 496 if (!dialect) 497 return failure(); 498 499 auto *interface = dyn_cast<DialectFoldInterface>(dialect); 500 if (!interface) 501 return failure(); 502 503 return interface->fold(this, operands, results); 504 } 505 506 /// Emit an error with the op name prefixed, like "'dim' op " which is 507 /// convenient for verifiers. 
508 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 509 return emitError() << "'" << getName() << "' op " << message; 510 } 511 512 //===----------------------------------------------------------------------===// 513 // Operation Cloning 514 //===----------------------------------------------------------------------===// 515 516 Operation::CloneOptions::CloneOptions() 517 : cloneRegionsFlag(false), cloneOperandsFlag(false) {} 518 519 Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands) 520 : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {} 521 522 Operation::CloneOptions Operation::CloneOptions::all() { 523 return CloneOptions().cloneRegions().cloneOperands(); 524 } 525 526 Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) { 527 cloneRegionsFlag = enable; 528 return *this; 529 } 530 531 Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) { 532 cloneOperandsFlag = enable; 533 return *this; 534 } 535 536 /// Create a deep copy of this operation but keep the operation regions empty. 537 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 538 /// to contain the results. The `mapResults` flag specifies whether the results 539 /// of the cloned operation should be added to the map. 540 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 541 return clone(mapper, CloneOptions::all().cloneRegions(false)); 542 } 543 544 Operation *Operation::cloneWithoutRegions() { 545 BlockAndValueMapping mapper; 546 return cloneWithoutRegions(mapper); 547 } 548 549 /// Create a deep copy of this operation, remapping any operands that use 550 /// values outside of the operation using the map that is provided (leaving 551 /// them alone if no entry is present). Replaces references to cloned 552 /// sub-operations to the corresponding operation that is copied, and adds 553 /// those mappings to the map. 
554 Operation *Operation::clone(BlockAndValueMapping &mapper, 555 CloneOptions options) { 556 SmallVector<Value, 8> operands; 557 SmallVector<Block *, 2> successors; 558 559 // Remap the operands. 560 if (options.shouldCloneOperands()) { 561 operands.reserve(getNumOperands()); 562 for (auto opValue : getOperands()) 563 operands.push_back(mapper.lookupOrDefault(opValue)); 564 } 565 566 // Remap the successors. 567 successors.reserve(getNumSuccessors()); 568 for (Block *successor : getSuccessors()) 569 successors.push_back(mapper.lookupOrDefault(successor)); 570 571 // Create the new operation. 572 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 573 successors, getNumRegions()); 574 575 // Clone the regions. 576 if (options.shouldCloneRegions()) { 577 for (unsigned i = 0; i != numRegions; ++i) 578 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 579 } 580 581 // Remember the mapping of any results. 582 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 583 mapper.map(getResult(i), newOp->getResult(i)); 584 585 return newOp; 586 } 587 588 Operation *Operation::clone(CloneOptions options) { 589 BlockAndValueMapping mapper; 590 return clone(mapper, options); 591 } 592 593 //===----------------------------------------------------------------------===// 594 // OpState trait class. 595 //===----------------------------------------------------------------------===// 596 597 // The fallback for the parser is to try for a dialect operation parser. 598 // Otherwise, reject the custom assembly form. 599 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 600 if (auto parseFn = result.name.getDialect()->getParseOperationHook( 601 result.name.getStringRef())) 602 return (*parseFn)(parser, result); 603 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 604 } 605 606 // The fallback for the printer is to try for a dialect operation printer. 607 // Otherwise, it prints the generic form. 
608 void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { 609 if (auto printFn = op->getDialect()->getOperationPrinter(op)) { 610 printOpName(op, p, defaultDialect); 611 printFn(op, p); 612 } else { 613 p.printGenericOp(op); 614 } 615 } 616 617 /// Print an operation name, eliding the dialect prefix if necessary and doesn't 618 /// lead to ambiguities. 619 void OpState::printOpName(Operation *op, OpAsmPrinter &p, 620 StringRef defaultDialect) { 621 StringRef name = op->getName().getStringRef(); 622 if (name.startswith((defaultDialect + ".").str()) && name.count('.') == 1) 623 name = name.drop_front(defaultDialect.size() + 1); 624 p.getStream() << name; 625 } 626 627 /// Emit an error about fatal conditions with this operation, reporting up to 628 /// any diagnostic handlers that may be listening. 629 InFlightDiagnostic OpState::emitError(const Twine &message) { 630 return getOperation()->emitError(message); 631 } 632 633 /// Emit an error with the op name prefixed, like "'dim' op " which is 634 /// convenient for verifiers. 635 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 636 return getOperation()->emitOpError(message); 637 } 638 639 /// Emit a warning about this operation, reporting up to any diagnostic 640 /// handlers that may be listening. 641 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 642 return getOperation()->emitWarning(message); 643 } 644 645 /// Emit a remark about this operation, reporting up to any diagnostic 646 /// handlers that may be listening. 
647 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 648 return getOperation()->emitRemark(message); 649 } 650 651 //===----------------------------------------------------------------------===// 652 // Op Trait implementations 653 //===----------------------------------------------------------------------===// 654 655 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 656 if (op->getNumOperands() == 1) { 657 auto *argumentOp = op->getOperand(0).getDefiningOp(); 658 if (argumentOp && op->getName() == argumentOp->getName()) { 659 // Replace the outer operation output with the inner operation. 660 return op->getOperand(0); 661 } 662 } else if (op->getOperand(0) == op->getOperand(1)) { 663 return op->getOperand(0); 664 } 665 666 return {}; 667 } 668 669 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 670 auto *argumentOp = op->getOperand(0).getDefiningOp(); 671 if (argumentOp && op->getName() == argumentOp->getName()) { 672 // Replace the outer involutions output with inner's input. 
673 return argumentOp->getOperand(0); 674 } 675 676 return {}; 677 } 678 679 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 680 if (op->getNumOperands() != 0) 681 return op->emitOpError() << "requires zero operands"; 682 return success(); 683 } 684 685 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 686 if (op->getNumOperands() != 1) 687 return op->emitOpError() << "requires a single operand"; 688 return success(); 689 } 690 691 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 692 unsigned numOperands) { 693 if (op->getNumOperands() != numOperands) { 694 return op->emitOpError() << "expected " << numOperands 695 << " operands, but found " << op->getNumOperands(); 696 } 697 return success(); 698 } 699 700 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 701 unsigned numOperands) { 702 if (op->getNumOperands() < numOperands) 703 return op->emitOpError() 704 << "expected " << numOperands << " or more operands, but found " 705 << op->getNumOperands(); 706 return success(); 707 } 708 709 /// If this is a vector type, or a tensor type, return the scalar element type 710 /// that it is built around, otherwise return the type unmodified. 711 static Type getTensorOrVectorElementType(Type type) { 712 if (auto vec = type.dyn_cast<VectorType>()) 713 return vec.getElementType(); 714 715 // Look through tensor<vector<...>> to find the underlying element type. 716 if (auto tensor = type.dyn_cast<TensorType>()) 717 return getTensorOrVectorElementType(tensor.getElementType()); 718 return type; 719 } 720 721 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 722 // FIXME: Add back check for no side effects on operation. 723 // Currently adding it would cause the shared library build 724 // to fail since there would be a dependency of IR on SideEffectInterfaces 725 // which is cyclical. 
726 return success(); 727 } 728 729 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 730 // FIXME: Add back check for no side effects on operation. 731 // Currently adding it would cause the shared library build 732 // to fail since there would be a dependency of IR on SideEffectInterfaces 733 // which is cyclical. 734 return success(); 735 } 736 737 LogicalResult 738 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 739 for (auto opType : op->getOperandTypes()) { 740 auto type = getTensorOrVectorElementType(opType); 741 if (!type.isSignlessIntOrIndex()) 742 return op->emitOpError() << "requires an integer or index type"; 743 } 744 return success(); 745 } 746 747 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 748 for (auto opType : op->getOperandTypes()) { 749 auto type = getTensorOrVectorElementType(opType); 750 if (!type.isa<FloatType>()) 751 return op->emitOpError("requires a float type"); 752 } 753 return success(); 754 } 755 756 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 757 // Zero or one operand always have the "same" type. 
758 unsigned nOperands = op->getNumOperands(); 759 if (nOperands < 2) 760 return success(); 761 762 auto type = op->getOperand(0).getType(); 763 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 764 if (opType != type) 765 return op->emitOpError() << "requires all operands to have the same type"; 766 return success(); 767 } 768 769 LogicalResult OpTrait::impl::verifyZeroRegions(Operation *op) { 770 if (op->getNumRegions() != 0) 771 return op->emitOpError() << "requires zero regions"; 772 return success(); 773 } 774 775 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 776 if (op->getNumRegions() != 1) 777 return op->emitOpError() << "requires one region"; 778 return success(); 779 } 780 781 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 782 unsigned numRegions) { 783 if (op->getNumRegions() != numRegions) 784 return op->emitOpError() << "expected " << numRegions << " regions"; 785 return success(); 786 } 787 788 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 789 unsigned numRegions) { 790 if (op->getNumRegions() < numRegions) 791 return op->emitOpError() << "expected " << numRegions << " or more regions"; 792 return success(); 793 } 794 795 LogicalResult OpTrait::impl::verifyZeroResults(Operation *op) { 796 if (op->getNumResults() != 0) 797 return op->emitOpError() << "requires zero results"; 798 return success(); 799 } 800 801 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 802 if (op->getNumResults() != 1) 803 return op->emitOpError() << "requires one result"; 804 return success(); 805 } 806 807 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 808 unsigned numOperands) { 809 if (op->getNumResults() != numOperands) 810 return op->emitOpError() << "expected " << numOperands << " results"; 811 return success(); 812 } 813 814 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 815 unsigned numOperands) { 816 if (op->getNumResults() < numOperands) 817 return op->emitOpError() 
818 << "expected " << numOperands << " or more results"; 819 return success(); 820 } 821 822 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 823 if (failed(verifyAtLeastNOperands(op, 1))) 824 return failure(); 825 826 if (failed(verifyCompatibleShapes(op->getOperandTypes()))) 827 return op->emitOpError() << "requires the same shape for all operands"; 828 829 return success(); 830 } 831 832 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 833 if (failed(verifyAtLeastNOperands(op, 1)) || 834 failed(verifyAtLeastNResults(op, 1))) 835 return failure(); 836 837 SmallVector<Type, 8> types(op->getOperandTypes()); 838 types.append(llvm::to_vector<4>(op->getResultTypes())); 839 840 if (failed(verifyCompatibleShapes(types))) 841 return op->emitOpError() 842 << "requires the same shape for all operands and results"; 843 844 return success(); 845 } 846 847 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 848 if (failed(verifyAtLeastNOperands(op, 1))) 849 return failure(); 850 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 851 852 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 853 if (getElementTypeOrSelf(operand) != elementType) 854 return op->emitOpError("requires the same element type for all operands"); 855 } 856 857 return success(); 858 } 859 860 LogicalResult 861 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 862 if (failed(verifyAtLeastNOperands(op, 1)) || 863 failed(verifyAtLeastNResults(op, 1))) 864 return failure(); 865 866 auto elementType = getElementTypeOrSelf(op->getResult(0)); 867 868 // Verify result element type matches first result's element type. 
869 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 870 if (getElementTypeOrSelf(result) != elementType) 871 return op->emitOpError( 872 "requires the same element type for all operands and results"); 873 } 874 875 // Verify operand's element type matches first result's element type. 876 for (auto operand : op->getOperands()) { 877 if (getElementTypeOrSelf(operand) != elementType) 878 return op->emitOpError( 879 "requires the same element type for all operands and results"); 880 } 881 882 return success(); 883 } 884 885 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 886 if (failed(verifyAtLeastNOperands(op, 1)) || 887 failed(verifyAtLeastNResults(op, 1))) 888 return failure(); 889 890 auto type = op->getResult(0).getType(); 891 auto elementType = getElementTypeOrSelf(type); 892 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 893 if (getElementTypeOrSelf(resultType) != elementType || 894 failed(verifyCompatibleShape(resultType, type))) 895 return op->emitOpError() 896 << "requires the same type for all operands and results"; 897 } 898 for (auto opType : op->getOperandTypes()) { 899 if (getElementTypeOrSelf(opType) != elementType || 900 failed(verifyCompatibleShape(opType, type))) 901 return op->emitOpError() 902 << "requires the same type for all operands and results"; 903 } 904 return success(); 905 } 906 907 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 908 Block *block = op->getBlock(); 909 // Verify that the operation is at the end of the respective parent block. 910 if (!block || &block->back() != op) 911 return op->emitOpError("must be the last operation in the parent block"); 912 return success(); 913 } 914 915 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 916 auto *parent = op->getParentRegion(); 917 918 // Verify that the operands lines up with the BB arguments in the successor. 
919 for (Block *succ : op->getSuccessors()) 920 if (succ->getParent() != parent) 921 return op->emitError("reference to block defined in another region"); 922 return success(); 923 } 924 925 LogicalResult OpTrait::impl::verifyZeroSuccessors(Operation *op) { 926 if (op->getNumSuccessors() != 0) { 927 return op->emitOpError("requires 0 successors but found ") 928 << op->getNumSuccessors(); 929 } 930 return success(); 931 } 932 933 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 934 if (op->getNumSuccessors() != 1) { 935 return op->emitOpError("requires 1 successor but found ") 936 << op->getNumSuccessors(); 937 } 938 return verifyTerminatorSuccessors(op); 939 } 940 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 941 unsigned numSuccessors) { 942 if (op->getNumSuccessors() != numSuccessors) { 943 return op->emitOpError("requires ") 944 << numSuccessors << " successors but found " 945 << op->getNumSuccessors(); 946 } 947 return verifyTerminatorSuccessors(op); 948 } 949 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 950 unsigned numSuccessors) { 951 if (op->getNumSuccessors() < numSuccessors) { 952 return op->emitOpError("requires at least ") 953 << numSuccessors << " successors but found " 954 << op->getNumSuccessors(); 955 } 956 return verifyTerminatorSuccessors(op); 957 } 958 959 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 960 for (auto resultType : op->getResultTypes()) { 961 auto elementType = getTensorOrVectorElementType(resultType); 962 bool isBoolType = elementType.isInteger(1); 963 if (!isBoolType) 964 return op->emitOpError() << "requires a bool result type"; 965 } 966 967 return success(); 968 } 969 970 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 971 for (auto resultType : op->getResultTypes()) 972 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 973 return op->emitOpError() << "requires a floating point type"; 974 975 return success(); 
976 } 977 978 LogicalResult 979 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 980 for (auto resultType : op->getResultTypes()) 981 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 982 return op->emitOpError() << "requires an integer or index type"; 983 return success(); 984 } 985 986 LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, 987 StringRef attrName, 988 StringRef valueGroupName, 989 size_t expectedCount) { 990 auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); 991 if (!sizeAttr) 992 return op->emitOpError("requires 1D i32 elements attribute '") 993 << attrName << "'"; 994 995 auto sizeAttrType = sizeAttr.getType(); 996 if (sizeAttrType.getRank() != 1 || 997 !sizeAttrType.getElementType().isInteger(32)) 998 return op->emitOpError("requires 1D i32 elements attribute '") 999 << attrName << "'"; 1000 1001 if (llvm::any_of(sizeAttr.getValues<APInt>(), [](const APInt &element) { 1002 return !element.isNonNegative(); 1003 })) 1004 return op->emitOpError("'") 1005 << attrName << "' attribute cannot have negative elements"; 1006 1007 size_t totalCount = std::accumulate( 1008 sizeAttr.begin(), sizeAttr.end(), 0, 1009 [](unsigned all, const APInt &one) { return all + one.getZExtValue(); }); 1010 1011 if (totalCount != expectedCount) 1012 return op->emitOpError() 1013 << valueGroupName << " count (" << expectedCount 1014 << ") does not match with the total size (" << totalCount 1015 << ") specified in attribute '" << attrName << "'"; 1016 return success(); 1017 } 1018 1019 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1020 StringRef attrName) { 1021 return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); 1022 } 1023 1024 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1025 StringRef attrName) { 1026 return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); 1027 } 1028 1029 LogicalResult 
OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1030 for (Region ®ion : op->getRegions()) { 1031 if (region.empty()) 1032 continue; 1033 1034 if (region.getNumArguments() != 0) { 1035 if (op->getNumRegions() > 1) 1036 return op->emitOpError("region #") 1037 << region.getRegionNumber() << " should have no arguments"; 1038 return op->emitOpError("region should have no arguments"); 1039 } 1040 } 1041 return success(); 1042 } 1043 1044 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1045 auto isMappableType = [](Type type) { 1046 return type.isa<VectorType, TensorType>(); 1047 }; 1048 auto resultMappableTypes = llvm::to_vector<1>( 1049 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1050 auto operandMappableTypes = llvm::to_vector<2>( 1051 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1052 1053 // If the op only has scalar operand/result types, then we have nothing to 1054 // check. 1055 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1056 return success(); 1057 1058 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1059 return op->emitOpError("if a result is non-scalar, then at least one " 1060 "operand must be non-scalar"); 1061 1062 assert(!operandMappableTypes.empty()); 1063 1064 if (resultMappableTypes.empty()) 1065 return op->emitOpError("if an operand is non-scalar, then there must be at " 1066 "least one non-scalar result"); 1067 1068 if (resultMappableTypes.size() != op->getNumResults()) 1069 return op->emitOpError( 1070 "if an operand is non-scalar, then all results must be non-scalar"); 1071 1072 SmallVector<Type, 4> types = llvm::to_vector<2>( 1073 llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); 1074 TypeID expectedBaseTy = types.front().getTypeID(); 1075 if (!llvm::all_of(types, 1076 [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || 1077 failed(verifyCompatibleShapes(types))) { 1078 return op->emitOpError() << "all non-scalar 
operands/results must have the " 1079 "same shape and base type"; 1080 } 1081 1082 return success(); 1083 } 1084 1085 /// Check for any values used by operations regions attached to the 1086 /// specified "IsIsolatedFromAbove" operation defined outside of it. 1087 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { 1088 assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1089 "Intended to check IsolatedFromAbove ops"); 1090 1091 // List of regions to analyze. Each region is processed independently, with 1092 // respect to the common `limit` region, so we can look at them in any order. 1093 // Therefore, use a simple vector and push/pop back the current region. 1094 SmallVector<Region *, 8> pendingRegions; 1095 for (auto ®ion : isolatedOp->getRegions()) { 1096 pendingRegions.push_back(®ion); 1097 1098 // Traverse all operations in the region. 1099 while (!pendingRegions.empty()) { 1100 for (Operation &op : pendingRegions.pop_back_val()->getOps()) { 1101 for (Value operand : op.getOperands()) { 1102 // Check that any value that is used by an operation is defined in the 1103 // same region as either an operation result. 1104 auto *operandRegion = operand.getParentRegion(); 1105 if (!operandRegion) 1106 return op.emitError("operation's operand is unlinked"); 1107 if (!region.isAncestor(operandRegion)) { 1108 return op.emitOpError("using value defined outside the region") 1109 .attachNote(isolatedOp->getLoc()) 1110 << "required by region isolation constraints"; 1111 } 1112 } 1113 1114 // Schedule any regions in the operation for further checking. Don't 1115 // recurse into other IsolatedFromAbove ops, because they will check 1116 // themselves. 
1117 if (op.getNumRegions() && 1118 !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1119 for (Region &subRegion : op.getRegions()) 1120 pendingRegions.push_back(&subRegion); 1121 } 1122 } 1123 } 1124 } 1125 1126 return success(); 1127 } 1128 1129 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1130 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1131 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1132 } 1133 1134 //===----------------------------------------------------------------------===// 1135 // CastOpInterface 1136 //===----------------------------------------------------------------------===// 1137 1138 /// Attempt to fold the given cast operation. 1139 LogicalResult 1140 impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, 1141 SmallVectorImpl<OpFoldResult> &foldResults) { 1142 OperandRange operands = op->getOperands(); 1143 if (operands.empty()) 1144 return failure(); 1145 ResultRange results = op->getResults(); 1146 1147 // Check for the case where the input and output types match 1-1. 1148 if (operands.getTypes() == results.getTypes()) { 1149 foldResults.append(operands.begin(), operands.end()); 1150 return success(); 1151 } 1152 1153 return failure(); 1154 } 1155 1156 /// Attempt to verify the given cast operation. 
1157 LogicalResult impl::verifyCastInterfaceOp( 1158 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { 1159 auto resultTypes = op->getResultTypes(); 1160 if (llvm::empty(resultTypes)) 1161 return op->emitOpError() 1162 << "expected at least one result for cast operation"; 1163 1164 auto operandTypes = op->getOperandTypes(); 1165 if (!areCastCompatible(operandTypes, resultTypes)) { 1166 InFlightDiagnostic diag = op->emitOpError("operand type"); 1167 if (llvm::empty(operandTypes)) 1168 diag << "s []"; 1169 else if (llvm::size(operandTypes) == 1) 1170 diag << " " << *operandTypes.begin(); 1171 else 1172 diag << "s " << operandTypes; 1173 return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") 1174 << resultTypes << " are cast incompatible"; 1175 } 1176 1177 return success(); 1178 } 1179 1180 //===----------------------------------------------------------------------===// 1181 // Misc. utils 1182 //===----------------------------------------------------------------------===// 1183 1184 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1185 /// region's only block if it does not have a terminator already. If the region 1186 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1187 /// terminator operation to insert. 1188 void impl::ensureRegionTerminator( 1189 Region ®ion, OpBuilder &builder, Location loc, 1190 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1191 OpBuilder::InsertionGuard guard(builder); 1192 if (region.empty()) 1193 builder.createBlock(®ion); 1194 1195 Block &block = region.back(); 1196 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1197 return; 1198 1199 builder.setInsertionPointToEnd(&block); 1200 builder.insert(buildTerminatorOp(builder, loc)); 1201 } 1202 1203 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1204 /// function. 
1205 void impl::ensureRegionTerminator( 1206 Region ®ion, Builder &builder, Location loc, 1207 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1208 OpBuilder opBuilder(builder.getContext()); 1209 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1210 } 1211