1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Interfaces/FoldInterfaces.h" 17 #include "llvm/ADT/StringExtras.h" 18 #include <numeric> 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // Operation 24 //===----------------------------------------------------------------------===// 25 26 /// Create a new Operation with the specific fields. 27 Operation *Operation::create(Location location, OperationName name, 28 TypeRange resultTypes, ValueRange operands, 29 ArrayRef<NamedAttribute> attributes, 30 BlockRange successors, unsigned numRegions) { 31 return create(location, name, resultTypes, operands, 32 DictionaryAttr::get(location.getContext(), attributes), 33 successors, numRegions); 34 } 35 36 /// Create a new Operation from operation state. 37 Operation *Operation::create(const OperationState &state) { 38 return create(state.location, state.name, state.types, state.operands, 39 state.attributes.getDictionary(state.getContext()), 40 state.successors, state.regions); 41 } 42 43 /// Create a new Operation with the specific fields. 
44 Operation *Operation::create(Location location, OperationName name, 45 TypeRange resultTypes, ValueRange operands, 46 DictionaryAttr attributes, BlockRange successors, 47 RegionRange regions) { 48 unsigned numRegions = regions.size(); 49 Operation *op = create(location, name, resultTypes, operands, attributes, 50 successors, numRegions); 51 for (unsigned i = 0; i < numRegions; ++i) 52 if (regions[i]) 53 op->getRegion(i).takeBody(*regions[i]); 54 return op; 55 } 56 57 /// Overload of create that takes an existing DictionaryAttr to avoid 58 /// unnecessarily uniquing a list of attributes. 59 Operation *Operation::create(Location location, OperationName name, 60 TypeRange resultTypes, ValueRange operands, 61 DictionaryAttr attributes, BlockRange successors, 62 unsigned numRegions) { 63 assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && 64 "unexpected null result type"); 65 66 // We only need to allocate additional memory for a subset of results. 67 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 68 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 69 unsigned numSuccessors = successors.size(); 70 unsigned numOperands = operands.size(); 71 unsigned numResults = resultTypes.size(); 72 73 // If the operation is known to have no operands, don't allocate an operand 74 // storage. 75 bool needsOperandStorage = 76 operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; 77 78 // Compute the byte size for the operation and the operand storage. This takes 79 // into account the size of the operation, its trailing objects, and its 80 // prefixed objects. 81 size_t byteSize = 82 totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>( 83 needsOperandStorage ? 
1 : 0, numSuccessors, numRegions, numOperands); 84 size_t prefixByteSize = llvm::alignTo( 85 Operation::prefixAllocSize(numTrailingResults, numInlineResults), 86 alignof(Operation)); 87 char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); 88 void *rawMem = mallocMem + prefixByteSize; 89 90 // Create the new Operation. 91 Operation *op = 92 ::new (rawMem) Operation(location, name, numResults, numSuccessors, 93 numRegions, attributes, needsOperandStorage); 94 95 assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && 96 "unexpected successors in a non-terminator operation"); 97 98 // Initialize the results. 99 auto resultTypeIt = resultTypes.begin(); 100 for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) 101 new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); 102 for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { 103 new (op->getOutOfLineOpResult(i)) 104 detail::OutOfLineOpResult(*resultTypeIt, i); 105 } 106 107 // Initialize the regions. 108 for (unsigned i = 0; i != numRegions; ++i) 109 new (&op->getRegion(i)) Region(op); 110 111 // Initialize the operands. 112 if (needsOperandStorage) { 113 new (&op->getOperandStorage()) detail::OperandStorage( 114 op, op->getTrailingObjects<OpOperand>(), operands); 115 } 116 117 // Initialize the successors. 
118 auto blockOperands = op->getBlockOperands(); 119 for (unsigned i = 0; i != numSuccessors; ++i) 120 new (&blockOperands[i]) BlockOperand(op, successors[i]); 121 122 return op; 123 } 124 125 Operation::Operation(Location location, OperationName name, unsigned numResults, 126 unsigned numSuccessors, unsigned numRegions, 127 DictionaryAttr attributes, bool hasOperandStorage) 128 : location(location), numResults(numResults), numSuccs(numSuccessors), 129 numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), 130 attrs(attributes) { 131 assert(attributes && "unexpected null attribute dictionary"); 132 #ifndef NDEBUG 133 if (!getDialect() && !getContext()->allowsUnregisteredDialects()) 134 llvm::report_fatal_error( 135 name.getStringRef() + 136 " created with unregistered dialect. If this is intended, please call " 137 "allowUnregisteredDialects() on the MLIRContext, or use " 138 "-allow-unregistered-dialect with the MLIR tool used."); 139 #endif 140 } 141 142 // Operations are deleted through the destroy() member because they are 143 // allocated via malloc. 144 Operation::~Operation() { 145 assert(block == nullptr && "operation destroyed but still in a block"); 146 #ifndef NDEBUG 147 if (!use_empty()) { 148 { 149 InFlightDiagnostic diag = 150 emitOpError("operation destroyed but still has uses"); 151 for (Operation *user : getUsers()) 152 diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; 153 } 154 llvm::report_fatal_error("operation destroyed but still has uses"); 155 } 156 #endif 157 // Explicitly run the destructors for the operands. 158 if (hasOperandStorage) 159 getOperandStorage().~OperandStorage(); 160 161 // Explicitly run the destructors for the successors. 162 for (auto &successor : getBlockOperands()) 163 successor.~BlockOperand(); 164 165 // Explicitly destroy the regions. 166 for (auto ®ion : getRegions()) 167 region.~Region(); 168 } 169 170 /// Destroy this operation or one of its subclasses. 
171 void Operation::destroy() { 172 // Operations may have additional prefixed allocation, which needs to be 173 // accounted for here when computing the address to free. 174 char *rawMem = reinterpret_cast<char *>(this) - 175 llvm::alignTo(prefixAllocSize(), alignof(Operation)); 176 this->~Operation(); 177 free(rawMem); 178 } 179 180 /// Return true if this operation is a proper ancestor of the `other` 181 /// operation. 182 bool Operation::isProperAncestor(Operation *other) { 183 while ((other = other->getParentOp())) 184 if (this == other) 185 return true; 186 return false; 187 } 188 189 /// Replace any uses of 'from' with 'to' within this operation. 190 void Operation::replaceUsesOfWith(Value from, Value to) { 191 if (from == to) 192 return; 193 for (auto &operand : getOpOperands()) 194 if (operand.get() == from) 195 operand.set(to); 196 } 197 198 /// Replace the current operands of this operation with the ones provided in 199 /// 'operands'. 200 void Operation::setOperands(ValueRange operands) { 201 if (LLVM_LIKELY(hasOperandStorage)) 202 return getOperandStorage().setOperands(this, operands); 203 assert(operands.empty() && "setting operands without an operand storage"); 204 } 205 206 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 207 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 208 /// than the range pointed to by 'start'+'length'. 209 void Operation::setOperands(unsigned start, unsigned length, 210 ValueRange operands) { 211 assert((start + length) <= getNumOperands() && 212 "invalid operand range specified"); 213 if (LLVM_LIKELY(hasOperandStorage)) 214 return getOperandStorage().setOperands(this, start, length, operands); 215 assert(operands.empty() && "setting operands without an operand storage"); 216 } 217 218 /// Insert the given operands into the operand list at the given 'index'. 
219 void Operation::insertOperands(unsigned index, ValueRange operands) { 220 if (LLVM_LIKELY(hasOperandStorage)) 221 return setOperands(index, /*length=*/0, operands); 222 assert(operands.empty() && "inserting operands without an operand storage"); 223 } 224 225 //===----------------------------------------------------------------------===// 226 // Diagnostics 227 //===----------------------------------------------------------------------===// 228 229 /// Emit an error about fatal conditions with this operation, reporting up to 230 /// any diagnostic handlers that may be listening. 231 InFlightDiagnostic Operation::emitError(const Twine &message) { 232 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 233 if (getContext()->shouldPrintOpOnDiagnostic()) { 234 diag.attachNote(getLoc()) 235 .append("see current operation: ") 236 .appendOp(*this, OpPrintingFlags().printGenericOpForm()); 237 } 238 return diag; 239 } 240 241 /// Emit a warning about this operation, reporting up to any diagnostic 242 /// handlers that may be listening. 243 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 244 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 245 if (getContext()->shouldPrintOpOnDiagnostic()) 246 diag.attachNote(getLoc()) << "see current operation: " << *this; 247 return diag; 248 } 249 250 /// Emit a remark about this operation, reporting up to any diagnostic 251 /// handlers that may be listening. 
InFlightDiagnostic Operation::emitRemark(const Twine &message) {
  InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message);
  // Optionally attach a note showing the operation the remark refers to.
  if (getContext()->shouldPrintOpOnDiagnostic())
    diag.attachNote(getLoc()) << "see current operation: " << *this;
  return diag;
}

//===----------------------------------------------------------------------===//
// Operation Ordering
//===----------------------------------------------------------------------===//

// Out-of-line definitions for the static constexpr members of Operation.
constexpr unsigned Operation::kInvalidOrderIdx;
constexpr unsigned Operation::kOrderStride;

/// Given an operation 'other' that is within the same parent block, return
/// whether the current operation is before 'other' in the operation list
/// of the parent block.
/// Note: This function has an average complexity of O(1), but worst case may
/// take O(N) where N is the number of operations within the parent block.
bool Operation::isBeforeInBlock(Operation *other) {
  assert(block && "Operations without parent blocks have no order.");
  assert(other && other->block == block &&
         "Expected other operation to have the same parent block.");
  // If the order of the block is already invalid, directly recompute the
  // order of the whole parent block.
  if (!block->isOpOrderValid()) {
    block->recomputeOpOrder();
  } else {
    // Update the order of either operation if necessary.
    updateOrderIfNecessary();
    other->updateOrderIfNecessary();
  }

  return orderIndex < other->orderIndex;
}

/// Update the order index of this operation if necessary, potentially
/// recomputing the order of the parent block.
///
/// Three positions are handled: the last operation of the block (previous
/// index plus the stride), the first operation (an index below the next one),
/// and an operation between two others (the midpoint of its neighbours). In
/// each case the whole block is recomputed when a neighbour has no valid
/// order or when no index is left to assign.
void Operation::updateOrderIfNecessary() {
  assert(block && "expected valid parent");

  // If the order is valid for this operation there is nothing to do.
  if (hasValidOrder())
    return;
  Operation *blockFront = &block->front();
  Operation *blockBack = &block->back();

  // This method is expected to only be invoked on blocks with more than one
  // operation.
  assert(blockFront != blockBack && "expected more than one operation");

  // If the operation is at the end of the block.
  if (this == blockBack) {
    Operation *prevNode = getPrevNode();
    if (!prevNode->hasValidOrder())
      return block->recomputeOpOrder();

    // Add the stride to the previous operation.
    orderIndex = prevNode->orderIndex + kOrderStride;
    return;
  }

  // If this is the first operation try to use the next operation to compute the
  // ordering.
  if (this == blockFront) {
    Operation *nextNode = getNextNode();
    if (!nextNode->hasValidOrder())
      return block->recomputeOpOrder();
    // There is no order to give this operation.
    if (nextNode->orderIndex == 0)
      return block->recomputeOpOrder();

    // If we can't use the stride, just take the middle value left. This is safe
    // because we know there is at least one valid index to assign to.
    if (nextNode->orderIndex <= kOrderStride)
      orderIndex = (nextNode->orderIndex / 2);
    else
      orderIndex = kOrderStride;
    return;
  }

  // Otherwise, this operation is between two others. Place this operation in
  // the middle of the previous and next if possible.
  Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
  if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
    return block->recomputeOpOrder();
  unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;

  // Check to see if there is a valid order between the two. Adjacent indices
  // leave no free slot, so the whole block has to be renumbered.
  if (prevOrder + 1 == nextOrder)
    return block->recomputeOpOrder();
  orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}

//===----------------------------------------------------------------------===//
// ilist_traits for Operation
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the node/value pointer conversions for the
// Operation intrusive list; each one forwards to the generic NodeAccess
// implementation.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer n)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *n)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

/// Operations in a list are released through Operation::destroy().
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
  op->destroy();
}

/// Recover the Block that owns this operation list. The byte offset of the
/// list member inside Block is computed from the member pointer returned by
/// Block::getSublistAccess (applied to a null Block pointer) and subtracted
/// from the address of this list.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset);
}

/// This is a trait method invoked when an operation is added to a block. We
/// keep the block pointer up to date.
388 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 389 assert(!op->getBlock() && "already in an operation block!"); 390 op->block = getContainingBlock(); 391 392 // Invalidate the order on the operation. 393 op->orderIndex = Operation::kInvalidOrderIdx; 394 } 395 396 /// This is a trait method invoked when an operation is removed from a block. 397 /// We keep the block pointer up to date. 398 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 399 assert(op->block && "not already in an operation block!"); 400 op->block = nullptr; 401 } 402 403 /// This is a trait method invoked when an operation is moved from one block 404 /// to another. We keep the block pointer up to date. 405 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 406 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 407 Block *curParent = getContainingBlock(); 408 409 // Invalidate the ordering of the parent block. 410 curParent->invalidateOpOrder(); 411 412 // If we are transferring operations within the same block, the block 413 // pointer doesn't need to be updated. 414 if (curParent == otherList.getContainingBlock()) 415 return; 416 417 // Update the 'block' member of each operation. 418 for (; first != last; ++first) 419 first->block = curParent; 420 } 421 422 /// Remove this operation (and its descendants) from its Block and delete 423 /// all of them. 424 void Operation::erase() { 425 if (auto *parent = getBlock()) 426 parent->getOperations().erase(this); 427 else 428 destroy(); 429 } 430 431 /// Remove the operation from its parent block, but don't delete it. 432 void Operation::remove() { 433 if (Block *parent = getBlock()) 434 parent->getOperations().remove(this); 435 } 436 437 /// Unlink this operation from its current block and insert it right before 438 /// `existingOp` which may be in the same or another block in the same 439 /// function. 
440 void Operation::moveBefore(Operation *existingOp) { 441 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 442 } 443 444 /// Unlink this operation from its current basic block and insert it right 445 /// before `iterator` in the specified basic block. 446 void Operation::moveBefore(Block *block, 447 llvm::iplist<Operation>::iterator iterator) { 448 block->getOperations().splice(iterator, getBlock()->getOperations(), 449 getIterator()); 450 } 451 452 /// Unlink this operation from its current block and insert it right after 453 /// `existingOp` which may be in the same or another block in the same function. 454 void Operation::moveAfter(Operation *existingOp) { 455 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 456 } 457 458 /// Unlink this operation from its current block and insert it right after 459 /// `iterator` in the specified block. 460 void Operation::moveAfter(Block *block, 461 llvm::iplist<Operation>::iterator iterator) { 462 assert(iterator != block->end() && "cannot move after end of block"); 463 moveBefore(block, std::next(iterator)); 464 } 465 466 /// This drops all operand uses from this operation, which is an essential 467 /// step in breaking cyclic dependences between references when they are to 468 /// be deleted. 469 void Operation::dropAllReferences() { 470 for (auto &op : getOpOperands()) 471 op.drop(); 472 473 for (auto ®ion : getRegions()) 474 region.dropAllReferences(); 475 476 for (auto &dest : getBlockOperands()) 477 dest.drop(); 478 } 479 480 /// This drops all uses of any values defined by this operation or its nested 481 /// regions, wherever they are located. 
482 void Operation::dropAllDefinedValueUses() { 483 dropAllUses(); 484 485 for (auto ®ion : getRegions()) 486 for (auto &block : region) 487 block.dropAllDefinedValueUses(); 488 } 489 490 void Operation::setSuccessor(Block *block, unsigned index) { 491 assert(index < getNumSuccessors()); 492 getBlockOperands()[index].set(block); 493 } 494 495 /// Attempt to fold this operation using the Op's registered foldHook. 496 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 497 SmallVectorImpl<OpFoldResult> &results) { 498 // If we have a registered operation definition matching this one, use it to 499 // try to constant fold the operation. 500 Optional<RegisteredOperationName> info = getRegisteredInfo(); 501 if (info && succeeded(info->foldHook(this, operands, results))) 502 return success(); 503 504 // Otherwise, fall back on the dialect hook to handle it. 505 Dialect *dialect = getDialect(); 506 if (!dialect) 507 return failure(); 508 509 auto *interface = dyn_cast<DialectFoldInterface>(dialect); 510 if (!interface) 511 return failure(); 512 513 return interface->fold(this, operands, results); 514 } 515 516 /// Emit an error with the op name prefixed, like "'dim' op " which is 517 /// convenient for verifiers. 
518 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 519 return emitError() << "'" << getName() << "' op " << message; 520 } 521 522 //===----------------------------------------------------------------------===// 523 // Operation Cloning 524 //===----------------------------------------------------------------------===// 525 526 Operation::CloneOptions::CloneOptions() 527 : cloneRegionsFlag(false), cloneOperandsFlag(false) {} 528 529 Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands) 530 : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {} 531 532 Operation::CloneOptions Operation::CloneOptions::all() { 533 return CloneOptions().cloneRegions().cloneOperands(); 534 } 535 536 Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) { 537 cloneRegionsFlag = enable; 538 return *this; 539 } 540 541 Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) { 542 cloneOperandsFlag = enable; 543 return *this; 544 } 545 546 /// Create a deep copy of this operation but keep the operation regions empty. 547 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 548 /// to contain the results. The `mapResults` flag specifies whether the results 549 /// of the cloned operation should be added to the map. 550 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 551 return clone(mapper, CloneOptions::all().cloneRegions(false)); 552 } 553 554 Operation *Operation::cloneWithoutRegions() { 555 BlockAndValueMapping mapper; 556 return cloneWithoutRegions(mapper); 557 } 558 559 /// Create a deep copy of this operation, remapping any operands that use 560 /// values outside of the operation using the map that is provided (leaving 561 /// them alone if no entry is present). Replaces references to cloned 562 /// sub-operations to the corresponding operation that is copied, and adds 563 /// those mappings to the map. 
564 Operation *Operation::clone(BlockAndValueMapping &mapper, 565 CloneOptions options) { 566 SmallVector<Value, 8> operands; 567 SmallVector<Block *, 2> successors; 568 569 // Remap the operands. 570 if (options.shouldCloneOperands()) { 571 operands.reserve(getNumOperands()); 572 for (auto opValue : getOperands()) 573 operands.push_back(mapper.lookupOrDefault(opValue)); 574 } 575 576 // Remap the successors. 577 successors.reserve(getNumSuccessors()); 578 for (Block *successor : getSuccessors()) 579 successors.push_back(mapper.lookupOrDefault(successor)); 580 581 // Create the new operation. 582 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 583 successors, getNumRegions()); 584 585 // Clone the regions. 586 if (options.shouldCloneRegions()) { 587 for (unsigned i = 0; i != numRegions; ++i) 588 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 589 } 590 591 // Remember the mapping of any results. 592 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 593 mapper.map(getResult(i), newOp->getResult(i)); 594 595 return newOp; 596 } 597 598 Operation *Operation::clone(CloneOptions options) { 599 BlockAndValueMapping mapper; 600 return clone(mapper, options); 601 } 602 603 //===----------------------------------------------------------------------===// 604 // OpState trait class. 605 //===----------------------------------------------------------------------===// 606 607 // The fallback for the parser is to try for a dialect operation parser. 608 // Otherwise, reject the custom assembly form. 609 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 610 if (auto parseFn = result.name.getDialect()->getParseOperationHook( 611 result.name.getStringRef())) 612 return (*parseFn)(parser, result); 613 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 614 } 615 616 // The fallback for the printer is to try for a dialect operation printer. 617 // Otherwise, it prints the generic form. 
618 void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { 619 if (auto printFn = op->getDialect()->getOperationPrinter(op)) { 620 printOpName(op, p, defaultDialect); 621 printFn(op, p); 622 } else { 623 p.printGenericOp(op); 624 } 625 } 626 627 /// Print an operation name, eliding the dialect prefix if necessary. 628 void OpState::printOpName(Operation *op, OpAsmPrinter &p, 629 StringRef defaultDialect) { 630 StringRef name = op->getName().getStringRef(); 631 if (name.startswith((defaultDialect + ".").str())) 632 name = name.drop_front(defaultDialect.size() + 1); 633 p.getStream() << name; 634 } 635 636 /// Emit an error about fatal conditions with this operation, reporting up to 637 /// any diagnostic handlers that may be listening. 638 InFlightDiagnostic OpState::emitError(const Twine &message) { 639 return getOperation()->emitError(message); 640 } 641 642 /// Emit an error with the op name prefixed, like "'dim' op " which is 643 /// convenient for verifiers. 644 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 645 return getOperation()->emitOpError(message); 646 } 647 648 /// Emit a warning about this operation, reporting up to any diagnostic 649 /// handlers that may be listening. 650 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 651 return getOperation()->emitWarning(message); 652 } 653 654 /// Emit a remark about this operation, reporting up to any diagnostic 655 /// handlers that may be listening. 
656 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 657 return getOperation()->emitRemark(message); 658 } 659 660 //===----------------------------------------------------------------------===// 661 // Op Trait implementations 662 //===----------------------------------------------------------------------===// 663 664 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 665 if (op->getNumOperands() == 1) { 666 auto *argumentOp = op->getOperand(0).getDefiningOp(); 667 if (argumentOp && op->getName() == argumentOp->getName()) { 668 // Replace the outer operation output with the inner operation. 669 return op->getOperand(0); 670 } 671 } else if (op->getOperand(0) == op->getOperand(1)) { 672 return op->getOperand(0); 673 } 674 675 return {}; 676 } 677 678 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 679 auto *argumentOp = op->getOperand(0).getDefiningOp(); 680 if (argumentOp && op->getName() == argumentOp->getName()) { 681 // Replace the outer involutions output with inner's input. 
682 return argumentOp->getOperand(0); 683 } 684 685 return {}; 686 } 687 688 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 689 if (op->getNumOperands() != 0) 690 return op->emitOpError() << "requires zero operands"; 691 return success(); 692 } 693 694 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 695 if (op->getNumOperands() != 1) 696 return op->emitOpError() << "requires a single operand"; 697 return success(); 698 } 699 700 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 701 unsigned numOperands) { 702 if (op->getNumOperands() != numOperands) { 703 return op->emitOpError() << "expected " << numOperands 704 << " operands, but found " << op->getNumOperands(); 705 } 706 return success(); 707 } 708 709 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 710 unsigned numOperands) { 711 if (op->getNumOperands() < numOperands) 712 return op->emitOpError() 713 << "expected " << numOperands << " or more operands, but found " 714 << op->getNumOperands(); 715 return success(); 716 } 717 718 /// If this is a vector type, or a tensor type, return the scalar element type 719 /// that it is built around, otherwise return the type unmodified. 720 static Type getTensorOrVectorElementType(Type type) { 721 if (auto vec = type.dyn_cast<VectorType>()) 722 return vec.getElementType(); 723 724 // Look through tensor<vector<...>> to find the underlying element type. 725 if (auto tensor = type.dyn_cast<TensorType>()) 726 return getTensorOrVectorElementType(tensor.getElementType()); 727 return type; 728 } 729 730 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 731 // FIXME: Add back check for no side effects on operation. 732 // Currently adding it would cause the shared library build 733 // to fail since there would be a dependency of IR on SideEffectInterfaces 734 // which is cyclical. 
735 return success(); 736 } 737 738 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 739 // FIXME: Add back check for no side effects on operation. 740 // Currently adding it would cause the shared library build 741 // to fail since there would be a dependency of IR on SideEffectInterfaces 742 // which is cyclical. 743 return success(); 744 } 745 746 LogicalResult 747 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 748 for (auto opType : op->getOperandTypes()) { 749 auto type = getTensorOrVectorElementType(opType); 750 if (!type.isSignlessIntOrIndex()) 751 return op->emitOpError() << "requires an integer or index type"; 752 } 753 return success(); 754 } 755 756 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 757 for (auto opType : op->getOperandTypes()) { 758 auto type = getTensorOrVectorElementType(opType); 759 if (!type.isa<FloatType>()) 760 return op->emitOpError("requires a float type"); 761 } 762 return success(); 763 } 764 765 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 766 // Zero or one operand always have the "same" type. 
767 unsigned nOperands = op->getNumOperands(); 768 if (nOperands < 2) 769 return success(); 770 771 auto type = op->getOperand(0).getType(); 772 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 773 if (opType != type) 774 return op->emitOpError() << "requires all operands to have the same type"; 775 return success(); 776 } 777 778 LogicalResult OpTrait::impl::verifyZeroRegions(Operation *op) { 779 if (op->getNumRegions() != 0) 780 return op->emitOpError() << "requires zero regions"; 781 return success(); 782 } 783 784 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 785 if (op->getNumRegions() != 1) 786 return op->emitOpError() << "requires one region"; 787 return success(); 788 } 789 790 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 791 unsigned numRegions) { 792 if (op->getNumRegions() != numRegions) 793 return op->emitOpError() << "expected " << numRegions << " regions"; 794 return success(); 795 } 796 797 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 798 unsigned numRegions) { 799 if (op->getNumRegions() < numRegions) 800 return op->emitOpError() << "expected " << numRegions << " or more regions"; 801 return success(); 802 } 803 804 LogicalResult OpTrait::impl::verifyZeroResults(Operation *op) { 805 if (op->getNumResults() != 0) 806 return op->emitOpError() << "requires zero results"; 807 return success(); 808 } 809 810 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 811 if (op->getNumResults() != 1) 812 return op->emitOpError() << "requires one result"; 813 return success(); 814 } 815 816 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 817 unsigned numOperands) { 818 if (op->getNumResults() != numOperands) 819 return op->emitOpError() << "expected " << numOperands << " results"; 820 return success(); 821 } 822 823 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 824 unsigned numOperands) { 825 if (op->getNumResults() < numOperands) 826 return op->emitOpError() 
827 << "expected " << numOperands << " or more results"; 828 return success(); 829 } 830 831 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 832 if (failed(verifyAtLeastNOperands(op, 1))) 833 return failure(); 834 835 if (failed(verifyCompatibleShapes(op->getOperandTypes()))) 836 return op->emitOpError() << "requires the same shape for all operands"; 837 838 return success(); 839 } 840 841 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 842 if (failed(verifyAtLeastNOperands(op, 1)) || 843 failed(verifyAtLeastNResults(op, 1))) 844 return failure(); 845 846 SmallVector<Type, 8> types(op->getOperandTypes()); 847 types.append(llvm::to_vector<4>(op->getResultTypes())); 848 849 if (failed(verifyCompatibleShapes(types))) 850 return op->emitOpError() 851 << "requires the same shape for all operands and results"; 852 853 return success(); 854 } 855 856 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 857 if (failed(verifyAtLeastNOperands(op, 1))) 858 return failure(); 859 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 860 861 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 862 if (getElementTypeOrSelf(operand) != elementType) 863 return op->emitOpError("requires the same element type for all operands"); 864 } 865 866 return success(); 867 } 868 869 LogicalResult 870 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 871 if (failed(verifyAtLeastNOperands(op, 1)) || 872 failed(verifyAtLeastNResults(op, 1))) 873 return failure(); 874 875 auto elementType = getElementTypeOrSelf(op->getResult(0)); 876 877 // Verify result element type matches first result's element type. 
878 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 879 if (getElementTypeOrSelf(result) != elementType) 880 return op->emitOpError( 881 "requires the same element type for all operands and results"); 882 } 883 884 // Verify operand's element type matches first result's element type. 885 for (auto operand : op->getOperands()) { 886 if (getElementTypeOrSelf(operand) != elementType) 887 return op->emitOpError( 888 "requires the same element type for all operands and results"); 889 } 890 891 return success(); 892 } 893 894 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 895 if (failed(verifyAtLeastNOperands(op, 1)) || 896 failed(verifyAtLeastNResults(op, 1))) 897 return failure(); 898 899 auto type = op->getResult(0).getType(); 900 auto elementType = getElementTypeOrSelf(type); 901 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 902 if (getElementTypeOrSelf(resultType) != elementType || 903 failed(verifyCompatibleShape(resultType, type))) 904 return op->emitOpError() 905 << "requires the same type for all operands and results"; 906 } 907 for (auto opType : op->getOperandTypes()) { 908 if (getElementTypeOrSelf(opType) != elementType || 909 failed(verifyCompatibleShape(opType, type))) 910 return op->emitOpError() 911 << "requires the same type for all operands and results"; 912 } 913 return success(); 914 } 915 916 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 917 Block *block = op->getBlock(); 918 // Verify that the operation is at the end of the respective parent block. 919 if (!block || &block->back() != op) 920 return op->emitOpError("must be the last operation in the parent block"); 921 return success(); 922 } 923 924 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 925 auto *parent = op->getParentRegion(); 926 927 // Verify that the operands lines up with the BB arguments in the successor. 
928 for (Block *succ : op->getSuccessors()) 929 if (succ->getParent() != parent) 930 return op->emitError("reference to block defined in another region"); 931 return success(); 932 } 933 934 LogicalResult OpTrait::impl::verifyZeroSuccessors(Operation *op) { 935 if (op->getNumSuccessors() != 0) { 936 return op->emitOpError("requires 0 successors but found ") 937 << op->getNumSuccessors(); 938 } 939 return success(); 940 } 941 942 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 943 if (op->getNumSuccessors() != 1) { 944 return op->emitOpError("requires 1 successor but found ") 945 << op->getNumSuccessors(); 946 } 947 return verifyTerminatorSuccessors(op); 948 } 949 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 950 unsigned numSuccessors) { 951 if (op->getNumSuccessors() != numSuccessors) { 952 return op->emitOpError("requires ") 953 << numSuccessors << " successors but found " 954 << op->getNumSuccessors(); 955 } 956 return verifyTerminatorSuccessors(op); 957 } 958 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 959 unsigned numSuccessors) { 960 if (op->getNumSuccessors() < numSuccessors) { 961 return op->emitOpError("requires at least ") 962 << numSuccessors << " successors but found " 963 << op->getNumSuccessors(); 964 } 965 return verifyTerminatorSuccessors(op); 966 } 967 968 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 969 for (auto resultType : op->getResultTypes()) { 970 auto elementType = getTensorOrVectorElementType(resultType); 971 bool isBoolType = elementType.isInteger(1); 972 if (!isBoolType) 973 return op->emitOpError() << "requires a bool result type"; 974 } 975 976 return success(); 977 } 978 979 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 980 for (auto resultType : op->getResultTypes()) 981 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 982 return op->emitOpError() << "requires a floating point type"; 983 984 return success(); 
985 } 986 987 LogicalResult 988 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 989 for (auto resultType : op->getResultTypes()) 990 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 991 return op->emitOpError() << "requires an integer or index type"; 992 return success(); 993 } 994 995 LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, 996 StringRef attrName, 997 StringRef valueGroupName, 998 size_t expectedCount) { 999 auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); 1000 if (!sizeAttr) 1001 return op->emitOpError("requires 1D i32 elements attribute '") 1002 << attrName << "'"; 1003 1004 auto sizeAttrType = sizeAttr.getType(); 1005 if (sizeAttrType.getRank() != 1 || 1006 !sizeAttrType.getElementType().isInteger(32)) 1007 return op->emitOpError("requires 1D i32 elements attribute '") 1008 << attrName << "'"; 1009 1010 if (llvm::any_of(sizeAttr.getValues<APInt>(), [](const APInt &element) { 1011 return !element.isNonNegative(); 1012 })) 1013 return op->emitOpError("'") 1014 << attrName << "' attribute cannot have negative elements"; 1015 1016 size_t totalCount = std::accumulate( 1017 sizeAttr.begin(), sizeAttr.end(), 0, 1018 [](unsigned all, const APInt &one) { return all + one.getZExtValue(); }); 1019 1020 if (totalCount != expectedCount) 1021 return op->emitOpError() 1022 << valueGroupName << " count (" << expectedCount 1023 << ") does not match with the total size (" << totalCount 1024 << ") specified in attribute '" << attrName << "'"; 1025 return success(); 1026 } 1027 1028 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1029 StringRef attrName) { 1030 return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); 1031 } 1032 1033 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1034 StringRef attrName) { 1035 return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); 1036 } 1037 1038 LogicalResult 
OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1039 for (Region ®ion : op->getRegions()) { 1040 if (region.empty()) 1041 continue; 1042 1043 if (region.getNumArguments() != 0) { 1044 if (op->getNumRegions() > 1) 1045 return op->emitOpError("region #") 1046 << region.getRegionNumber() << " should have no arguments"; 1047 return op->emitOpError("region should have no arguments"); 1048 } 1049 } 1050 return success(); 1051 } 1052 1053 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1054 auto isMappableType = [](Type type) { 1055 return type.isa<VectorType, TensorType>(); 1056 }; 1057 auto resultMappableTypes = llvm::to_vector<1>( 1058 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1059 auto operandMappableTypes = llvm::to_vector<2>( 1060 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1061 1062 // If the op only has scalar operand/result types, then we have nothing to 1063 // check. 1064 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1065 return success(); 1066 1067 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1068 return op->emitOpError("if a result is non-scalar, then at least one " 1069 "operand must be non-scalar"); 1070 1071 assert(!operandMappableTypes.empty()); 1072 1073 if (resultMappableTypes.empty()) 1074 return op->emitOpError("if an operand is non-scalar, then there must be at " 1075 "least one non-scalar result"); 1076 1077 if (resultMappableTypes.size() != op->getNumResults()) 1078 return op->emitOpError( 1079 "if an operand is non-scalar, then all results must be non-scalar"); 1080 1081 SmallVector<Type, 4> types = llvm::to_vector<2>( 1082 llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); 1083 TypeID expectedBaseTy = types.front().getTypeID(); 1084 if (!llvm::all_of(types, 1085 [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || 1086 failed(verifyCompatibleShapes(types))) { 1087 return op->emitOpError() << "all non-scalar 
operands/results must have the " 1088 "same shape and base type"; 1089 } 1090 1091 return success(); 1092 } 1093 1094 /// Check for any values used by operations regions attached to the 1095 /// specified "IsIsolatedFromAbove" operation defined outside of it. 1096 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { 1097 assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1098 "Intended to check IsolatedFromAbove ops"); 1099 1100 // List of regions to analyze. Each region is processed independently, with 1101 // respect to the common `limit` region, so we can look at them in any order. 1102 // Therefore, use a simple vector and push/pop back the current region. 1103 SmallVector<Region *, 8> pendingRegions; 1104 for (auto ®ion : isolatedOp->getRegions()) { 1105 pendingRegions.push_back(®ion); 1106 1107 // Traverse all operations in the region. 1108 while (!pendingRegions.empty()) { 1109 for (Operation &op : pendingRegions.pop_back_val()->getOps()) { 1110 for (Value operand : op.getOperands()) { 1111 // Check that any value that is used by an operation is defined in the 1112 // same region as either an operation result. 1113 auto *operandRegion = operand.getParentRegion(); 1114 if (!operandRegion) 1115 return op.emitError("operation's operand is unlinked"); 1116 if (!region.isAncestor(operandRegion)) { 1117 return op.emitOpError("using value defined outside the region") 1118 .attachNote(isolatedOp->getLoc()) 1119 << "required by region isolation constraints"; 1120 } 1121 } 1122 1123 // Schedule any regions in the operation for further checking. Don't 1124 // recurse into other IsolatedFromAbove ops, because they will check 1125 // themselves. 
1126 if (op.getNumRegions() && 1127 !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1128 for (Region &subRegion : op.getRegions()) 1129 pendingRegions.push_back(&subRegion); 1130 } 1131 } 1132 } 1133 } 1134 1135 return success(); 1136 } 1137 1138 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1139 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1140 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1141 } 1142 1143 //===----------------------------------------------------------------------===// 1144 // CastOpInterface 1145 //===----------------------------------------------------------------------===// 1146 1147 /// Attempt to fold the given cast operation. 1148 LogicalResult 1149 impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, 1150 SmallVectorImpl<OpFoldResult> &foldResults) { 1151 OperandRange operands = op->getOperands(); 1152 if (operands.empty()) 1153 return failure(); 1154 ResultRange results = op->getResults(); 1155 1156 // Check for the case where the input and output types match 1-1. 1157 if (operands.getTypes() == results.getTypes()) { 1158 foldResults.append(operands.begin(), operands.end()); 1159 return success(); 1160 } 1161 1162 return failure(); 1163 } 1164 1165 /// Attempt to verify the given cast operation. 
1166 LogicalResult impl::verifyCastInterfaceOp( 1167 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { 1168 auto resultTypes = op->getResultTypes(); 1169 if (llvm::empty(resultTypes)) 1170 return op->emitOpError() 1171 << "expected at least one result for cast operation"; 1172 1173 auto operandTypes = op->getOperandTypes(); 1174 if (!areCastCompatible(operandTypes, resultTypes)) { 1175 InFlightDiagnostic diag = op->emitOpError("operand type"); 1176 if (llvm::empty(operandTypes)) 1177 diag << "s []"; 1178 else if (llvm::size(operandTypes) == 1) 1179 diag << " " << *operandTypes.begin(); 1180 else 1181 diag << "s " << operandTypes; 1182 return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") 1183 << resultTypes << " are cast incompatible"; 1184 } 1185 1186 return success(); 1187 } 1188 1189 //===----------------------------------------------------------------------===// 1190 // Misc. utils 1191 //===----------------------------------------------------------------------===// 1192 1193 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1194 /// region's only block if it does not have a terminator already. If the region 1195 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1196 /// terminator operation to insert. 1197 void impl::ensureRegionTerminator( 1198 Region ®ion, OpBuilder &builder, Location loc, 1199 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1200 OpBuilder::InsertionGuard guard(builder); 1201 if (region.empty()) 1202 builder.createBlock(®ion); 1203 1204 Block &block = region.back(); 1205 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1206 return; 1207 1208 builder.setInsertionPointToEnd(&block); 1209 builder.insert(buildTerminatorOp(builder, loc)); 1210 } 1211 1212 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1213 /// function. 
1214 void impl::ensureRegionTerminator( 1215 Region ®ion, Builder &builder, Location loc, 1216 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1217 OpBuilder opBuilder(builder.getContext()); 1218 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1219 } 1220