1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Interfaces/FoldInterfaces.h" 17 #include "llvm/ADT/StringExtras.h" 18 #include <numeric> 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // Operation 24 //===----------------------------------------------------------------------===// 25 26 /// Create a new Operation with the specific fields. 27 Operation *Operation::create(Location location, OperationName name, 28 TypeRange resultTypes, ValueRange operands, 29 ArrayRef<NamedAttribute> attributes, 30 BlockRange successors, unsigned numRegions) { 31 return create(location, name, resultTypes, operands, 32 DictionaryAttr::get(location.getContext(), attributes), 33 successors, numRegions); 34 } 35 36 /// Create a new Operation from operation state. 37 Operation *Operation::create(const OperationState &state) { 38 return create(state.location, state.name, state.types, state.operands, 39 state.attributes.getDictionary(state.getContext()), 40 state.successors, state.regions); 41 } 42 43 /// Create a new Operation with the specific fields. 44 Operation *Operation::create(Location location, OperationName name, 45 TypeRange resultTypes, ValueRange operands, 46 DictionaryAttr attributes, BlockRange successors, 47 RegionRange regions) { 48 unsigned numRegions = regions.size(); 49 Operation *op = create(location, name, resultTypes, operands, attributes, 50 successors, numRegions); 51 for (unsigned i = 0; i < numRegions; ++i) 52 if (regions[i]) 53 op->getRegion(i).takeBody(*regions[i]); 54 return op; 55 } 56 57 /// Overload of create that takes an existing DictionaryAttr to avoid 58 /// unnecessarily uniquing a list of attributes. 59 Operation *Operation::create(Location location, OperationName name, 60 TypeRange resultTypes, ValueRange operands, 61 DictionaryAttr attributes, BlockRange successors, 62 unsigned numRegions) { 63 assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && 64 "unexpected null result type"); 65 66 // We only need to allocate additional memory for a subset of results. 67 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 68 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 69 unsigned numSuccessors = successors.size(); 70 unsigned numOperands = operands.size(); 71 unsigned numResults = resultTypes.size(); 72 73 // If the operation is known to have no operands, don't allocate an operand 74 // storage. 75 bool needsOperandStorage = 76 operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; 77 78 // Compute the byte size for the operation and the operand storage. This takes 79 // into account the size of the operation, its trailing objects, and its 80 // prefixed objects. 81 size_t byteSize = 82 totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>( 83 needsOperandStorage ? 1 : 0, numSuccessors, numRegions, numOperands); 84 size_t prefixByteSize = llvm::alignTo( 85 Operation::prefixAllocSize(numTrailingResults, numInlineResults), 86 alignof(Operation)); 87 char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); 88 void *rawMem = mallocMem + prefixByteSize; 89 90 // Create the new Operation. 91 Operation *op = 92 ::new (rawMem) Operation(location, name, numResults, numSuccessors, 93 numRegions, attributes, needsOperandStorage); 94 95 assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && 96 "unexpected successors in a non-terminator operation"); 97 98 // Initialize the results. 99 auto resultTypeIt = resultTypes.begin(); 100 for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) 101 new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); 102 for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { 103 new (op->getOutOfLineOpResult(i)) 104 detail::OutOfLineOpResult(*resultTypeIt, i); 105 } 106 107 // Initialize the regions. 108 for (unsigned i = 0; i != numRegions; ++i) 109 new (&op->getRegion(i)) Region(op); 110 111 // Initialize the operands. 112 if (needsOperandStorage) { 113 new (&op->getOperandStorage()) detail::OperandStorage( 114 op, op->getTrailingObjects<OpOperand>(), operands); 115 } 116 117 // Initialize the successors. 118 auto blockOperands = op->getBlockOperands(); 119 for (unsigned i = 0; i != numSuccessors; ++i) 120 new (&blockOperands[i]) BlockOperand(op, successors[i]); 121 122 return op; 123 } 124 125 Operation::Operation(Location location, OperationName name, unsigned numResults, 126 unsigned numSuccessors, unsigned numRegions, 127 DictionaryAttr attributes, bool hasOperandStorage) 128 : location(location), numResults(numResults), numSuccs(numSuccessors), 129 numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), 130 attrs(attributes) { 131 assert(attributes && "unexpected null attribute dictionary"); 132 #ifndef NDEBUG 133 if (!getDialect() && !getContext()->allowsUnregisteredDialects()) 134 llvm::report_fatal_error( 135 name.getStringRef() + 136 " created with unregistered dialect. If this is intended, please call " 137 "allowUnregisteredDialects() on the MLIRContext, or use " 138 "-allow-unregistered-dialect with the MLIR tool used."); 139 #endif 140 } 141 142 // Operations are deleted through the destroy() member because they are 143 // allocated via malloc. 144 Operation::~Operation() { 145 assert(block == nullptr && "operation destroyed but still in a block"); 146 #ifndef NDEBUG 147 if (!use_empty()) { 148 { 149 InFlightDiagnostic diag = 150 emitOpError("operation destroyed but still has uses"); 151 for (Operation *user : getUsers()) 152 diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; 153 } 154 llvm::report_fatal_error("operation destroyed but still has uses"); 155 } 156 #endif 157 // Explicitly run the destructors for the operands. 158 if (hasOperandStorage) 159 getOperandStorage().~OperandStorage(); 160 161 // Explicitly run the destructors for the successors. 162 for (auto &successor : getBlockOperands()) 163 successor.~BlockOperand(); 164 165 // Explicitly destroy the regions. 166 for (auto ®ion : getRegions()) 167 region.~Region(); 168 } 169 170 /// Destroy this operation or one of its subclasses. 171 void Operation::destroy() { 172 // Operations may have additional prefixed allocation, which needs to be 173 // accounted for here when computing the address to free. 174 char *rawMem = reinterpret_cast<char *>(this) - 175 llvm::alignTo(prefixAllocSize(), alignof(Operation)); 176 this->~Operation(); 177 free(rawMem); 178 } 179 180 /// Return true if this operation is a proper ancestor of the `other` 181 /// operation. 182 bool Operation::isProperAncestor(Operation *other) { 183 while ((other = other->getParentOp())) 184 if (this == other) 185 return true; 186 return false; 187 } 188 189 /// Replace any uses of 'from' with 'to' within this operation. 190 void Operation::replaceUsesOfWith(Value from, Value to) { 191 if (from == to) 192 return; 193 for (auto &operand : getOpOperands()) 194 if (operand.get() == from) 195 operand.set(to); 196 } 197 198 /// Replace the current operands of this operation with the ones provided in 199 /// 'operands'. 200 void Operation::setOperands(ValueRange operands) { 201 if (LLVM_LIKELY(hasOperandStorage)) 202 return getOperandStorage().setOperands(this, operands); 203 assert(operands.empty() && "setting operands without an operand storage"); 204 } 205 206 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 207 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 208 /// than the range pointed to by 'start'+'length'. 209 void Operation::setOperands(unsigned start, unsigned length, 210 ValueRange operands) { 211 assert((start + length) <= getNumOperands() && 212 "invalid operand range specified"); 213 if (LLVM_LIKELY(hasOperandStorage)) 214 return getOperandStorage().setOperands(this, start, length, operands); 215 assert(operands.empty() && "setting operands without an operand storage"); 216 } 217 218 /// Insert the given operands into the operand list at the given 'index'. 219 void Operation::insertOperands(unsigned index, ValueRange operands) { 220 if (LLVM_LIKELY(hasOperandStorage)) 221 return setOperands(index, /*length=*/0, operands); 222 assert(operands.empty() && "inserting operands without an operand storage"); 223 } 224 225 //===----------------------------------------------------------------------===// 226 // Diagnostics 227 //===----------------------------------------------------------------------===// 228 229 /// Emit an error about fatal conditions with this operation, reporting up to 230 /// any diagnostic handlers that may be listening. 231 InFlightDiagnostic Operation::emitError(const Twine &message) { 232 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 233 if (getContext()->shouldPrintOpOnDiagnostic()) { 234 diag.attachNote(getLoc()) 235 .append("see current operation: ") 236 .appendOp(*this, OpPrintingFlags().printGenericOpForm()); 237 } 238 return diag; 239 } 240 241 /// Emit a warning about this operation, reporting up to any diagnostic 242 /// handlers that may be listening. 243 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 244 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 245 if (getContext()->shouldPrintOpOnDiagnostic()) 246 diag.attachNote(getLoc()) << "see current operation: " << *this; 247 return diag; 248 } 249 250 /// Emit a remark about this operation, reporting up to any diagnostic 251 /// handlers that may be listening. 252 InFlightDiagnostic Operation::emitRemark(const Twine &message) { 253 InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); 254 if (getContext()->shouldPrintOpOnDiagnostic()) 255 diag.attachNote(getLoc()) << "see current operation: " << *this; 256 return diag; 257 } 258 259 //===----------------------------------------------------------------------===// 260 // Operation Ordering 261 //===----------------------------------------------------------------------===// 262 263 constexpr unsigned Operation::kInvalidOrderIdx; 264 constexpr unsigned Operation::kOrderStride; 265 266 /// Given an operation 'other' that is within the same parent block, return 267 /// whether the current operation is before 'other' in the operation list 268 /// of the parent block. 269 /// Note: This function has an average complexity of O(1), but worst case may 270 /// take O(N) where N is the number of operations within the parent block. 271 bool Operation::isBeforeInBlock(Operation *other) { 272 assert(block && "Operations without parent blocks have no order."); 273 assert(other && other->block == block && 274 "Expected other operation to have the same parent block."); 275 // If the order of the block is already invalid, directly recompute the 276 // parent. 277 if (!block->isOpOrderValid()) { 278 block->recomputeOpOrder(); 279 } else { 280 // Update the order either operation if necessary. 281 updateOrderIfNecessary(); 282 other->updateOrderIfNecessary(); 283 } 284 285 return orderIndex < other->orderIndex; 286 } 287 288 /// Update the order index of this operation of this operation if necessary, 289 /// potentially recomputing the order of the parent block. 290 void Operation::updateOrderIfNecessary() { 291 assert(block && "expected valid parent"); 292 293 // If the order is valid for this operation there is nothing to do. 294 if (hasValidOrder()) 295 return; 296 Operation *blockFront = &block->front(); 297 Operation *blockBack = &block->back(); 298 299 // This method is expected to only be invoked on blocks with more than one 300 // operation. 301 assert(blockFront != blockBack && "expected more than one operation"); 302 303 // If the operation is at the end of the block. 304 if (this == blockBack) { 305 Operation *prevNode = getPrevNode(); 306 if (!prevNode->hasValidOrder()) 307 return block->recomputeOpOrder(); 308 309 // Add the stride to the previous operation. 310 orderIndex = prevNode->orderIndex + kOrderStride; 311 return; 312 } 313 314 // If this is the first operation try to use the next operation to compute the 315 // ordering. 316 if (this == blockFront) { 317 Operation *nextNode = getNextNode(); 318 if (!nextNode->hasValidOrder()) 319 return block->recomputeOpOrder(); 320 // There is no order to give this operation. 321 if (nextNode->orderIndex == 0) 322 return block->recomputeOpOrder(); 323 324 // If we can't use the stride, just take the middle value left. This is safe 325 // because we know there is at least one valid index to assign to. 326 if (nextNode->orderIndex <= kOrderStride) 327 orderIndex = (nextNode->orderIndex / 2); 328 else 329 orderIndex = kOrderStride; 330 return; 331 } 332 333 // Otherwise, this operation is between two others. Place this operation in 334 // the middle of the previous and next if possible. 335 Operation *prevNode = getPrevNode(), *nextNode = getNextNode(); 336 if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder()) 337 return block->recomputeOpOrder(); 338 unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex; 339 340 // Check to see if there is a valid order between the two. 341 if (prevOrder + 1 == nextOrder) 342 return block->recomputeOpOrder(); 343 orderIndex = prevOrder + ((nextOrder - prevOrder) / 2); 344 } 345 346 //===----------------------------------------------------------------------===// 347 // ilist_traits for Operation 348 //===----------------------------------------------------------------------===// 349 350 auto llvm::ilist_detail::SpecificNodeAccess< 351 typename llvm::ilist_detail::compute_node_options< 352 ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * { 353 return NodeAccess::getNodePtr<OptionsT>(n); 354 } 355 356 auto llvm::ilist_detail::SpecificNodeAccess< 357 typename llvm::ilist_detail::compute_node_options< 358 ::mlir::Operation>::type>::getNodePtr(const_pointer n) 359 -> const node_type * { 360 return NodeAccess::getNodePtr<OptionsT>(n); 361 } 362 363 auto llvm::ilist_detail::SpecificNodeAccess< 364 typename llvm::ilist_detail::compute_node_options< 365 ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer { 366 return NodeAccess::getValuePtr<OptionsT>(n); 367 } 368 369 auto llvm::ilist_detail::SpecificNodeAccess< 370 typename llvm::ilist_detail::compute_node_options< 371 ::mlir::Operation>::type>::getValuePtr(const node_type *n) 372 -> const_pointer { 373 return NodeAccess::getValuePtr<OptionsT>(n); 374 } 375 376 void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { 377 op->destroy(); 378 } 379 380 Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { 381 size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); 382 iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this)); 383 return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset); 384 } 385 386 /// This is a trait method invoked when an operation is added to a block. We 387 /// keep the block pointer up to date. 388 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 389 assert(!op->getBlock() && "already in an operation block!"); 390 op->block = getContainingBlock(); 391 392 // Invalidate the order on the operation. 393 op->orderIndex = Operation::kInvalidOrderIdx; 394 } 395 396 /// This is a trait method invoked when an operation is removed from a block. 397 /// We keep the block pointer up to date. 398 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 399 assert(op->block && "not already in an operation block!"); 400 op->block = nullptr; 401 } 402 403 /// This is a trait method invoked when an operation is moved from one block 404 /// to another. We keep the block pointer up to date. 405 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 406 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 407 Block *curParent = getContainingBlock(); 408 409 // Invalidate the ordering of the parent block. 410 curParent->invalidateOpOrder(); 411 412 // If we are transferring operations within the same block, the block 413 // pointer doesn't need to be updated. 414 if (curParent == otherList.getContainingBlock()) 415 return; 416 417 // Update the 'block' member of each operation. 418 for (; first != last; ++first) 419 first->block = curParent; 420 } 421 422 /// Remove this operation (and its descendants) from its Block and delete 423 /// all of them. 424 void Operation::erase() { 425 if (auto *parent = getBlock()) 426 parent->getOperations().erase(this); 427 else 428 destroy(); 429 } 430 431 /// Remove the operation from its parent block, but don't delete it. 432 void Operation::remove() { 433 if (Block *parent = getBlock()) 434 parent->getOperations().remove(this); 435 } 436 437 /// Unlink this operation from its current block and insert it right before 438 /// `existingOp` which may be in the same or another block in the same 439 /// function. 440 void Operation::moveBefore(Operation *existingOp) { 441 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 442 } 443 444 /// Unlink this operation from its current basic block and insert it right 445 /// before `iterator` in the specified basic block. 446 void Operation::moveBefore(Block *block, 447 llvm::iplist<Operation>::iterator iterator) { 448 block->getOperations().splice(iterator, getBlock()->getOperations(), 449 getIterator()); 450 } 451 452 /// Unlink this operation from its current block and insert it right after 453 /// `existingOp` which may be in the same or another block in the same function. 454 void Operation::moveAfter(Operation *existingOp) { 455 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 456 } 457 458 /// Unlink this operation from its current block and insert it right after 459 /// `iterator` in the specified block. 460 void Operation::moveAfter(Block *block, 461 llvm::iplist<Operation>::iterator iterator) { 462 assert(iterator != block->end() && "cannot move after end of block"); 463 moveBefore(block, std::next(iterator)); 464 } 465 466 /// This drops all operand uses from this operation, which is an essential 467 /// step in breaking cyclic dependences between references when they are to 468 /// be deleted. 469 void Operation::dropAllReferences() { 470 for (auto &op : getOpOperands()) 471 op.drop(); 472 473 for (auto ®ion : getRegions()) 474 region.dropAllReferences(); 475 476 for (auto &dest : getBlockOperands()) 477 dest.drop(); 478 } 479 480 /// This drops all uses of any values defined by this operation or its nested 481 /// regions, wherever they are located. 482 void Operation::dropAllDefinedValueUses() { 483 dropAllUses(); 484 485 for (auto ®ion : getRegions()) 486 for (auto &block : region) 487 block.dropAllDefinedValueUses(); 488 } 489 490 void Operation::setSuccessor(Block *block, unsigned index) { 491 assert(index < getNumSuccessors()); 492 getBlockOperands()[index].set(block); 493 } 494 495 /// Attempt to fold this operation using the Op's registered foldHook. 496 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 497 SmallVectorImpl<OpFoldResult> &results) { 498 // If we have a registered operation definition matching this one, use it to 499 // try to constant fold the operation. 500 Optional<RegisteredOperationName> info = getRegisteredInfo(); 501 if (info && succeeded(info->foldHook(this, operands, results))) 502 return success(); 503 504 // Otherwise, fall back on the dialect hook to handle it. 505 Dialect *dialect = getDialect(); 506 if (!dialect) 507 return failure(); 508 509 auto *interface = dyn_cast<DialectFoldInterface>(dialect); 510 if (!interface) 511 return failure(); 512 513 return interface->fold(this, operands, results); 514 } 515 516 /// Emit an error with the op name prefixed, like "'dim' op " which is 517 /// convenient for verifiers. 518 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 519 return emitError() << "'" << getName() << "' op " << message; 520 } 521 522 //===----------------------------------------------------------------------===// 523 // Operation Cloning 524 //===----------------------------------------------------------------------===// 525 526 Operation::CloneOptions::CloneOptions() 527 : cloneRegionsFlag(false), cloneOperandsFlag(false) {} 528 529 Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands) 530 : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {} 531 532 Operation::CloneOptions Operation::CloneOptions::all() { 533 return CloneOptions().cloneRegions().cloneOperands(); 534 } 535 536 Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) { 537 cloneRegionsFlag = enable; 538 return *this; 539 } 540 541 Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) { 542 cloneOperandsFlag = enable; 543 return *this; 544 } 545 546 /// Create a deep copy of this operation but keep the operation regions empty. 547 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 548 /// to contain the results. The `mapResults` flag specifies whether the results 549 /// of the cloned operation should be added to the map. 550 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 551 return clone(mapper, CloneOptions::all().cloneRegions(false)); 552 } 553 554 Operation *Operation::cloneWithoutRegions() { 555 BlockAndValueMapping mapper; 556 return cloneWithoutRegions(mapper); 557 } 558 559 /// Create a deep copy of this operation, remapping any operands that use 560 /// values outside of the operation using the map that is provided (leaving 561 /// them alone if no entry is present). Replaces references to cloned 562 /// sub-operations to the corresponding operation that is copied, and adds 563 /// those mappings to the map. 564 Operation *Operation::clone(BlockAndValueMapping &mapper, 565 CloneOptions options) { 566 SmallVector<Value, 8> operands; 567 SmallVector<Block *, 2> successors; 568 569 // Remap the operands. 570 if (options.shouldCloneOperands()) { 571 operands.reserve(getNumOperands()); 572 for (auto opValue : getOperands()) 573 operands.push_back(mapper.lookupOrDefault(opValue)); 574 } 575 576 // Remap the successors. 577 successors.reserve(getNumSuccessors()); 578 for (Block *successor : getSuccessors()) 579 successors.push_back(mapper.lookupOrDefault(successor)); 580 581 // Create the new operation. 582 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 583 successors, getNumRegions()); 584 585 // Clone the regions. 586 if (options.shouldCloneRegions()) { 587 for (unsigned i = 0; i != numRegions; ++i) 588 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 589 } 590 591 // Remember the mapping of any results. 592 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 593 mapper.map(getResult(i), newOp->getResult(i)); 594 595 return newOp; 596 } 597 598 Operation *Operation::clone(CloneOptions options) { 599 BlockAndValueMapping mapper; 600 return clone(mapper, options); 601 } 602 603 //===----------------------------------------------------------------------===// 604 // OpState trait class. 605 //===----------------------------------------------------------------------===// 606 607 // The fallback for the parser is to try for a dialect operation parser. 608 // Otherwise, reject the custom assembly form. 609 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 610 if (auto parseFn = result.name.getDialect()->getParseOperationHook( 611 result.name.getStringRef())) 612 return (*parseFn)(parser, result); 613 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 614 } 615 616 // The fallback for the printer is to try for a dialect operation printer. 617 // Otherwise, it prints the generic form. 618 void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { 619 if (auto printFn = op->getDialect()->getOperationPrinter(op)) { 620 printOpName(op, p, defaultDialect); 621 printFn(op, p); 622 } else { 623 p.printGenericOp(op); 624 } 625 } 626 627 /// Print an operation name, eliding the dialect prefix if necessary and doesn't 628 /// lead to ambiguities. 629 void OpState::printOpName(Operation *op, OpAsmPrinter &p, 630 StringRef defaultDialect) { 631 StringRef name = op->getName().getStringRef(); 632 if (name.startswith((defaultDialect + ".").str()) && name.count('.') == 1) 633 name = name.drop_front(defaultDialect.size() + 1); 634 p.getStream() << name; 635 } 636 637 /// Emit an error about fatal conditions with this operation, reporting up to 638 /// any diagnostic handlers that may be listening. 639 InFlightDiagnostic OpState::emitError(const Twine &message) { 640 return getOperation()->emitError(message); 641 } 642 643 /// Emit an error with the op name prefixed, like "'dim' op " which is 644 /// convenient for verifiers. 645 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 646 return getOperation()->emitOpError(message); 647 } 648 649 /// Emit a warning about this operation, reporting up to any diagnostic 650 /// handlers that may be listening. 651 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 652 return getOperation()->emitWarning(message); 653 } 654 655 /// Emit a remark about this operation, reporting up to any diagnostic 656 /// handlers that may be listening. 657 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 658 return getOperation()->emitRemark(message); 659 } 660 661 //===----------------------------------------------------------------------===// 662 // Op Trait implementations 663 //===----------------------------------------------------------------------===// 664 665 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 666 if (op->getNumOperands() == 1) { 667 auto *argumentOp = op->getOperand(0).getDefiningOp(); 668 if (argumentOp && op->getName() == argumentOp->getName()) { 669 // Replace the outer operation output with the inner operation. 670 return op->getOperand(0); 671 } 672 } else if (op->getOperand(0) == op->getOperand(1)) { 673 return op->getOperand(0); 674 } 675 676 return {}; 677 } 678 679 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 680 auto *argumentOp = op->getOperand(0).getDefiningOp(); 681 if (argumentOp && op->getName() == argumentOp->getName()) { 682 // Replace the outer involutions output with inner's input. 683 return argumentOp->getOperand(0); 684 } 685 686 return {}; 687 } 688 689 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 690 if (op->getNumOperands() != 0) 691 return op->emitOpError() << "requires zero operands"; 692 return success(); 693 } 694 695 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 696 if (op->getNumOperands() != 1) 697 return op->emitOpError() << "requires a single operand"; 698 return success(); 699 } 700 701 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 702 unsigned numOperands) { 703 if (op->getNumOperands() != numOperands) { 704 return op->emitOpError() << "expected " << numOperands 705 << " operands, but found " << op->getNumOperands(); 706 } 707 return success(); 708 } 709 710 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 711 unsigned numOperands) { 712 if (op->getNumOperands() < numOperands) 713 return op->emitOpError() 714 << "expected " << numOperands << " or more operands, but found " 715 << op->getNumOperands(); 716 return success(); 717 } 718 719 /// If this is a vector type, or a tensor type, return the scalar element type 720 /// that it is built around, otherwise return the type unmodified. 721 static Type getTensorOrVectorElementType(Type type) { 722 if (auto vec = type.dyn_cast<VectorType>()) 723 return vec.getElementType(); 724 725 // Look through tensor<vector<...>> to find the underlying element type. 726 if (auto tensor = type.dyn_cast<TensorType>()) 727 return getTensorOrVectorElementType(tensor.getElementType()); 728 return type; 729 } 730 731 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 732 // FIXME: Add back check for no side effects on operation. 733 // Currently adding it would cause the shared library build 734 // to fail since there would be a dependency of IR on SideEffectInterfaces 735 // which is cyclical. 736 return success(); 737 } 738 739 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 740 // FIXME: Add back check for no side effects on operation. 741 // Currently adding it would cause the shared library build 742 // to fail since there would be a dependency of IR on SideEffectInterfaces 743 // which is cyclical. 744 return success(); 745 } 746 747 LogicalResult 748 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 749 for (auto opType : op->getOperandTypes()) { 750 auto type = getTensorOrVectorElementType(opType); 751 if (!type.isSignlessIntOrIndex()) 752 return op->emitOpError() << "requires an integer or index type"; 753 } 754 return success(); 755 } 756 757 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 758 for (auto opType : op->getOperandTypes()) { 759 auto type = getTensorOrVectorElementType(opType); 760 if (!type.isa<FloatType>()) 761 return op->emitOpError("requires a float type"); 762 } 763 return success(); 764 } 765 766 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 767 // Zero or one operand always have the "same" type. 768 unsigned nOperands = op->getNumOperands(); 769 if (nOperands < 2) 770 return success(); 771 772 auto type = op->getOperand(0).getType(); 773 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 774 if (opType != type) 775 return op->emitOpError() << "requires all operands to have the same type"; 776 return success(); 777 } 778 779 LogicalResult OpTrait::impl::verifyZeroRegions(Operation *op) { 780 if (op->getNumRegions() != 0) 781 return op->emitOpError() << "requires zero regions"; 782 return success(); 783 } 784 785 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 786 if (op->getNumRegions() != 1) 787 return op->emitOpError() << "requires one region"; 788 return success(); 789 } 790 791 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 792 unsigned numRegions) { 793 if (op->getNumRegions() != numRegions) 794 return op->emitOpError() << "expected " << numRegions << " regions"; 795 return success(); 796 } 797 798 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 799 unsigned numRegions) { 800 if (op->getNumRegions() < numRegions) 801 return op->emitOpError() << "expected " << numRegions << " or more regions"; 802 return success(); 803 } 804 805 LogicalResult OpTrait::impl::verifyZeroResults(Operation *op) { 806 if (op->getNumResults() != 0) 807 return op->emitOpError() << "requires zero results"; 808 return success(); 809 } 810 811 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 812 if (op->getNumResults() != 1) 813 return op->emitOpError() << "requires one result"; 814 return success(); 815 } 816 817 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 818 unsigned numOperands) { 819 if (op->getNumResults() != numOperands) 820 return op->emitOpError() << "expected " << numOperands << " results"; 821 return success(); 822 } 823 824 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 825 unsigned numOperands) { 826 if (op->getNumResults() < numOperands) 827 return op->emitOpError() 828 << "expected " << numOperands << " or more results"; 829 return success(); 830 } 831 832 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 833 if (failed(verifyAtLeastNOperands(op, 1))) 834 return failure(); 835 836 if (failed(verifyCompatibleShapes(op->getOperandTypes()))) 837 return op->emitOpError() << "requires the same shape for all operands"; 838 839 return success(); 840 } 841 842 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 843 if (failed(verifyAtLeastNOperands(op, 1)) || 844 failed(verifyAtLeastNResults(op, 1))) 845 return failure(); 846 847 SmallVector<Type, 8> types(op->getOperandTypes()); 848 types.append(llvm::to_vector<4>(op->getResultTypes())); 849 850 if (failed(verifyCompatibleShapes(types))) 851 return op->emitOpError() 852 << "requires the same shape for all operands and results"; 853 854 return success(); 855 } 856 857 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 858 if (failed(verifyAtLeastNOperands(op, 1))) 859 return failure(); 860 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 861 862 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 863 if (getElementTypeOrSelf(operand) != elementType) 864 return op->emitOpError("requires the same element type for all operands"); 865 } 866 867 return success(); 868 } 869 870 LogicalResult 871 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 872 if (failed(verifyAtLeastNOperands(op, 1)) || 873 failed(verifyAtLeastNResults(op, 1))) 874 return failure(); 875 876 auto elementType = getElementTypeOrSelf(op->getResult(0)); 877 878 // Verify result element type matches first result's element type. 879 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 880 if (getElementTypeOrSelf(result) != elementType) 881 return op->emitOpError( 882 "requires the same element type for all operands and results"); 883 } 884 885 // Verify operand's element type matches first result's element type. 886 for (auto operand : op->getOperands()) { 887 if (getElementTypeOrSelf(operand) != elementType) 888 return op->emitOpError( 889 "requires the same element type for all operands and results"); 890 } 891 892 return success(); 893 } 894 895 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 896 if (failed(verifyAtLeastNOperands(op, 1)) || 897 failed(verifyAtLeastNResults(op, 1))) 898 return failure(); 899 900 auto type = op->getResult(0).getType(); 901 auto elementType = getElementTypeOrSelf(type); 902 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 903 if (getElementTypeOrSelf(resultType) != elementType || 904 failed(verifyCompatibleShape(resultType, type))) 905 return op->emitOpError() 906 << "requires the same type for all operands and results"; 907 } 908 for (auto opType : op->getOperandTypes()) { 909 if (getElementTypeOrSelf(opType) != elementType || 910 failed(verifyCompatibleShape(opType, type))) 911 return op->emitOpError() 912 << "requires the same type for all operands and results"; 913 } 914 return success(); 915 } 916 917 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 918 Block *block = op->getBlock(); 919 // Verify that the operation is at the end of the respective parent block. 920 if (!block || &block->back() != op) 921 return op->emitOpError("must be the last operation in the parent block"); 922 return success(); 923 } 924 925 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 926 auto *parent = op->getParentRegion(); 927 928 // Verify that the operands lines up with the BB arguments in the successor. 929 for (Block *succ : op->getSuccessors()) 930 if (succ->getParent() != parent) 931 return op->emitError("reference to block defined in another region"); 932 return success(); 933 } 934 935 LogicalResult OpTrait::impl::verifyZeroSuccessors(Operation *op) { 936 if (op->getNumSuccessors() != 0) { 937 return op->emitOpError("requires 0 successors but found ") 938 << op->getNumSuccessors(); 939 } 940 return success(); 941 } 942 943 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 944 if (op->getNumSuccessors() != 1) { 945 return op->emitOpError("requires 1 successor but found ") 946 << op->getNumSuccessors(); 947 } 948 return verifyTerminatorSuccessors(op); 949 } 950 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 951 unsigned numSuccessors) { 952 if (op->getNumSuccessors() != numSuccessors) { 953 return op->emitOpError("requires ") 954 << numSuccessors << " successors but found " 955 << op->getNumSuccessors(); 956 } 957 return verifyTerminatorSuccessors(op); 958 } 959 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 960 unsigned numSuccessors) { 961 if (op->getNumSuccessors() < numSuccessors) { 962 return op->emitOpError("requires at least ") 963 << numSuccessors << " successors but found " 964 << op->getNumSuccessors(); 965 } 966 return verifyTerminatorSuccessors(op); 967 } 968 969 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 970 for (auto resultType : op->getResultTypes()) { 971 auto elementType = getTensorOrVectorElementType(resultType); 972 bool isBoolType = elementType.isInteger(1); 973 if (!isBoolType) 974 return op->emitOpError() << "requires a bool result type"; 975 } 976 977 return success(); 978 } 979 980 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 981 for (auto resultType : op->getResultTypes()) 982 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 983 return op->emitOpError() << "requires a floating point type"; 984 985 return success(); 986 } 987 988 LogicalResult 989 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 990 for (auto resultType : op->getResultTypes()) 991 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 992 return op->emitOpError() << "requires an integer or index type"; 993 return success(); 994 } 995 996 LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, 997 StringRef attrName, 998 StringRef valueGroupName, 999 size_t expectedCount) { 1000 auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); 1001 if (!sizeAttr) 1002 return op->emitOpError("requires 1D i32 elements attribute '") 1003 << attrName << "'"; 1004 1005 auto sizeAttrType = sizeAttr.getType(); 1006 if (sizeAttrType.getRank() != 1 || 1007 !sizeAttrType.getElementType().isInteger(32)) 1008 return op->emitOpError("requires 1D i32 elements attribute '") 1009 << attrName << "'"; 1010 1011 if (llvm::any_of(sizeAttr.getValues<APInt>(), [](const APInt &element) { 1012 return !element.isNonNegative(); 1013 })) 1014 return op->emitOpError("'") 1015 << attrName << "' attribute cannot have negative elements"; 1016 1017 size_t totalCount = std::accumulate( 1018 sizeAttr.begin(), sizeAttr.end(), 0, 1019 [](unsigned all, const APInt &one) { return all + one.getZExtValue(); }); 1020 1021 if (totalCount != expectedCount) 1022 return op->emitOpError() 1023 << valueGroupName << " count (" << expectedCount 1024 << ") does not match with the total size (" << totalCount 1025 << ") specified in attribute '" << attrName << "'"; 1026 return success(); 1027 } 1028 1029 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1030 StringRef attrName) { 1031 return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); 1032 } 1033 1034 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1035 StringRef attrName) { 1036 return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); 1037 } 1038 1039 LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1040 for (Region ®ion : op->getRegions()) { 1041 if (region.empty()) 1042 continue; 1043 1044 if (region.getNumArguments() != 0) { 1045 if (op->getNumRegions() > 1) 1046 return op->emitOpError("region #") 1047 << region.getRegionNumber() << " should have no arguments"; 1048 return op->emitOpError("region should have no arguments"); 1049 } 1050 } 1051 return success(); 1052 } 1053 1054 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1055 auto isMappableType = [](Type type) { 1056 return type.isa<VectorType, TensorType>(); 1057 }; 1058 auto resultMappableTypes = llvm::to_vector<1>( 1059 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1060 auto operandMappableTypes = llvm::to_vector<2>( 1061 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1062 1063 // If the op only has scalar operand/result types, then we have nothing to 1064 // check. 1065 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1066 return success(); 1067 1068 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1069 return op->emitOpError("if a result is non-scalar, then at least one " 1070 "operand must be non-scalar"); 1071 1072 assert(!operandMappableTypes.empty()); 1073 1074 if (resultMappableTypes.empty()) 1075 return op->emitOpError("if an operand is non-scalar, then there must be at " 1076 "least one non-scalar result"); 1077 1078 if (resultMappableTypes.size() != op->getNumResults()) 1079 return op->emitOpError( 1080 "if an operand is non-scalar, then all results must be non-scalar"); 1081 1082 SmallVector<Type, 4> types = llvm::to_vector<2>( 1083 llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); 1084 TypeID expectedBaseTy = types.front().getTypeID(); 1085 if (!llvm::all_of(types, 1086 [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || 1087 failed(verifyCompatibleShapes(types))) { 1088 return op->emitOpError() << "all non-scalar operands/results must have the " 1089 "same shape and base type"; 1090 } 1091 1092 return success(); 1093 } 1094 1095 /// Check for any values used by operations regions attached to the 1096 /// specified "IsIsolatedFromAbove" operation defined outside of it. 1097 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { 1098 assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1099 "Intended to check IsolatedFromAbove ops"); 1100 1101 // List of regions to analyze. Each region is processed independently, with 1102 // respect to the common `limit` region, so we can look at them in any order. 1103 // Therefore, use a simple vector and push/pop back the current region. 1104 SmallVector<Region *, 8> pendingRegions; 1105 for (auto ®ion : isolatedOp->getRegions()) { 1106 pendingRegions.push_back(®ion); 1107 1108 // Traverse all operations in the region. 1109 while (!pendingRegions.empty()) { 1110 for (Operation &op : pendingRegions.pop_back_val()->getOps()) { 1111 for (Value operand : op.getOperands()) { 1112 // Check that any value that is used by an operation is defined in the 1113 // same region as either an operation result. 1114 auto *operandRegion = operand.getParentRegion(); 1115 if (!operandRegion) 1116 return op.emitError("operation's operand is unlinked"); 1117 if (!region.isAncestor(operandRegion)) { 1118 return op.emitOpError("using value defined outside the region") 1119 .attachNote(isolatedOp->getLoc()) 1120 << "required by region isolation constraints"; 1121 } 1122 } 1123 1124 // Schedule any regions in the operation for further checking. Don't 1125 // recurse into other IsolatedFromAbove ops, because they will check 1126 // themselves. 1127 if (op.getNumRegions() && 1128 !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1129 for (Region &subRegion : op.getRegions()) 1130 pendingRegions.push_back(&subRegion); 1131 } 1132 } 1133 } 1134 } 1135 1136 return success(); 1137 } 1138 1139 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1140 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1141 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1142 } 1143 1144 //===----------------------------------------------------------------------===// 1145 // CastOpInterface 1146 //===----------------------------------------------------------------------===// 1147 1148 /// Attempt to fold the given cast operation. 1149 LogicalResult 1150 impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, 1151 SmallVectorImpl<OpFoldResult> &foldResults) { 1152 OperandRange operands = op->getOperands(); 1153 if (operands.empty()) 1154 return failure(); 1155 ResultRange results = op->getResults(); 1156 1157 // Check for the case where the input and output types match 1-1. 1158 if (operands.getTypes() == results.getTypes()) { 1159 foldResults.append(operands.begin(), operands.end()); 1160 return success(); 1161 } 1162 1163 return failure(); 1164 } 1165 1166 /// Attempt to verify the given cast operation. 1167 LogicalResult impl::verifyCastInterfaceOp( 1168 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { 1169 auto resultTypes = op->getResultTypes(); 1170 if (llvm::empty(resultTypes)) 1171 return op->emitOpError() 1172 << "expected at least one result for cast operation"; 1173 1174 auto operandTypes = op->getOperandTypes(); 1175 if (!areCastCompatible(operandTypes, resultTypes)) { 1176 InFlightDiagnostic diag = op->emitOpError("operand type"); 1177 if (llvm::empty(operandTypes)) 1178 diag << "s []"; 1179 else if (llvm::size(operandTypes) == 1) 1180 diag << " " << *operandTypes.begin(); 1181 else 1182 diag << "s " << operandTypes; 1183 return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") 1184 << resultTypes << " are cast incompatible"; 1185 } 1186 1187 return success(); 1188 } 1189 1190 //===----------------------------------------------------------------------===// 1191 // Misc. utils 1192 //===----------------------------------------------------------------------===// 1193 1194 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1195 /// region's only block if it does not have a terminator already. If the region 1196 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1197 /// terminator operation to insert. 1198 void impl::ensureRegionTerminator( 1199 Region ®ion, OpBuilder &builder, Location loc, 1200 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1201 OpBuilder::InsertionGuard guard(builder); 1202 if (region.empty()) 1203 builder.createBlock(®ion); 1204 1205 Block &block = region.back(); 1206 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1207 return; 1208 1209 builder.setInsertionPointToEnd(&block); 1210 builder.insert(buildTerminatorOp(builder, loc)); 1211 } 1212 1213 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1214 /// function. 1215 void impl::ensureRegionTerminator( 1216 Region ®ion, Builder &builder, Location loc, 1217 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1218 OpBuilder opBuilder(builder.getContext()); 1219 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1220 } 1221