1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Interfaces/FoldInterfaces.h" 17 #include "llvm/ADT/StringExtras.h" 18 #include <numeric> 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // Operation 24 //===----------------------------------------------------------------------===// 25 26 /// Create a new Operation with the specific fields. 27 Operation *Operation::create(Location location, OperationName name, 28 TypeRange resultTypes, ValueRange operands, 29 ArrayRef<NamedAttribute> attributes, 30 BlockRange successors, unsigned numRegions) { 31 return create(location, name, resultTypes, operands, 32 DictionaryAttr::get(location.getContext(), attributes), 33 successors, numRegions); 34 } 35 36 /// Create a new Operation from operation state. 37 Operation *Operation::create(const OperationState &state) { 38 return create(state.location, state.name, state.types, state.operands, 39 state.attributes.getDictionary(state.getContext()), 40 state.successors, state.regions); 41 } 42 43 /// Create a new Operation with the specific fields. 
44 Operation *Operation::create(Location location, OperationName name, 45 TypeRange resultTypes, ValueRange operands, 46 DictionaryAttr attributes, BlockRange successors, 47 RegionRange regions) { 48 unsigned numRegions = regions.size(); 49 Operation *op = create(location, name, resultTypes, operands, attributes, 50 successors, numRegions); 51 for (unsigned i = 0; i < numRegions; ++i) 52 if (regions[i]) 53 op->getRegion(i).takeBody(*regions[i]); 54 return op; 55 } 56 57 /// Overload of create that takes an existing DictionaryAttr to avoid 58 /// unnecessarily uniquing a list of attributes. 59 Operation *Operation::create(Location location, OperationName name, 60 TypeRange resultTypes, ValueRange operands, 61 DictionaryAttr attributes, BlockRange successors, 62 unsigned numRegions) { 63 assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && 64 "unexpected null result type"); 65 66 // We only need to allocate additional memory for a subset of results. 67 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 68 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 69 unsigned numSuccessors = successors.size(); 70 unsigned numOperands = operands.size(); 71 unsigned numResults = resultTypes.size(); 72 73 // If the operation is known to have no operands, don't allocate an operand 74 // storage. 75 bool needsOperandStorage = 76 operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; 77 78 // Compute the byte size for the operation and the operand storage. This takes 79 // into account the size of the operation, its trailing objects, and its 80 // prefixed objects. 81 size_t byteSize = 82 totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>( 83 needsOperandStorage ? 
1 : 0, numSuccessors, numRegions, numOperands); 84 size_t prefixByteSize = llvm::alignTo( 85 Operation::prefixAllocSize(numTrailingResults, numInlineResults), 86 alignof(Operation)); 87 char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); 88 void *rawMem = mallocMem + prefixByteSize; 89 90 // Create the new Operation. 91 Operation *op = 92 ::new (rawMem) Operation(location, name, numResults, numSuccessors, 93 numRegions, attributes, needsOperandStorage); 94 95 assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && 96 "unexpected successors in a non-terminator operation"); 97 98 // Initialize the results. 99 auto resultTypeIt = resultTypes.begin(); 100 for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) 101 new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); 102 for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { 103 new (op->getOutOfLineOpResult(i)) 104 detail::OutOfLineOpResult(*resultTypeIt, i); 105 } 106 107 // Initialize the regions. 108 for (unsigned i = 0; i != numRegions; ++i) 109 new (&op->getRegion(i)) Region(op); 110 111 // Initialize the operands. 112 if (needsOperandStorage) { 113 new (&op->getOperandStorage()) detail::OperandStorage( 114 op, op->getTrailingObjects<OpOperand>(), operands); 115 } 116 117 // Initialize the successors. 
118 auto blockOperands = op->getBlockOperands(); 119 for (unsigned i = 0; i != numSuccessors; ++i) 120 new (&blockOperands[i]) BlockOperand(op, successors[i]); 121 122 return op; 123 } 124 125 Operation::Operation(Location location, OperationName name, unsigned numResults, 126 unsigned numSuccessors, unsigned numRegions, 127 DictionaryAttr attributes, bool hasOperandStorage) 128 : location(location), numResults(numResults), numSuccs(numSuccessors), 129 numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), 130 attrs(attributes) { 131 assert(attributes && "unexpected null attribute dictionary"); 132 #ifndef NDEBUG 133 if (!getDialect() && !getContext()->allowsUnregisteredDialects()) 134 llvm::report_fatal_error( 135 name.getStringRef() + 136 " created with unregistered dialect. If this is intended, please call " 137 "allowUnregisteredDialects() on the MLIRContext, or use " 138 "-allow-unregistered-dialect with the MLIR tool used."); 139 #endif 140 } 141 142 // Operations are deleted through the destroy() member because they are 143 // allocated via malloc. 144 Operation::~Operation() { 145 assert(block == nullptr && "operation destroyed but still in a block"); 146 #ifndef NDEBUG 147 if (!use_empty()) { 148 { 149 InFlightDiagnostic diag = 150 emitOpError("operation destroyed but still has uses"); 151 for (Operation *user : getUsers()) 152 diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; 153 } 154 llvm::report_fatal_error("operation destroyed but still has uses"); 155 } 156 #endif 157 // Explicitly run the destructors for the operands. 158 if (hasOperandStorage) 159 getOperandStorage().~OperandStorage(); 160 161 // Explicitly run the destructors for the successors. 162 for (auto &successor : getBlockOperands()) 163 successor.~BlockOperand(); 164 165 // Explicitly destroy the regions. 166 for (auto ®ion : getRegions()) 167 region.~Region(); 168 } 169 170 /// Destroy this operation or one of its subclasses. 
171 void Operation::destroy() { 172 // Operations may have additional prefixed allocation, which needs to be 173 // accounted for here when computing the address to free. 174 char *rawMem = reinterpret_cast<char *>(this) - 175 llvm::alignTo(prefixAllocSize(), alignof(Operation)); 176 this->~Operation(); 177 free(rawMem); 178 } 179 180 /// Return true if this operation is a proper ancestor of the `other` 181 /// operation. 182 bool Operation::isProperAncestor(Operation *other) { 183 while ((other = other->getParentOp())) 184 if (this == other) 185 return true; 186 return false; 187 } 188 189 /// Replace any uses of 'from' with 'to' within this operation. 190 void Operation::replaceUsesOfWith(Value from, Value to) { 191 if (from == to) 192 return; 193 for (auto &operand : getOpOperands()) 194 if (operand.get() == from) 195 operand.set(to); 196 } 197 198 /// Replace the current operands of this operation with the ones provided in 199 /// 'operands'. 200 void Operation::setOperands(ValueRange operands) { 201 if (LLVM_LIKELY(hasOperandStorage)) 202 return getOperandStorage().setOperands(this, operands); 203 assert(operands.empty() && "setting operands without an operand storage"); 204 } 205 206 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 207 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 208 /// than the range pointed to by 'start'+'length'. 209 void Operation::setOperands(unsigned start, unsigned length, 210 ValueRange operands) { 211 assert((start + length) <= getNumOperands() && 212 "invalid operand range specified"); 213 if (LLVM_LIKELY(hasOperandStorage)) 214 return getOperandStorage().setOperands(this, start, length, operands); 215 assert(operands.empty() && "setting operands without an operand storage"); 216 } 217 218 /// Insert the given operands into the operand list at the given 'index'. 
219 void Operation::insertOperands(unsigned index, ValueRange operands) { 220 if (LLVM_LIKELY(hasOperandStorage)) 221 return setOperands(index, /*length=*/0, operands); 222 assert(operands.empty() && "inserting operands without an operand storage"); 223 } 224 225 //===----------------------------------------------------------------------===// 226 // Diagnostics 227 //===----------------------------------------------------------------------===// 228 229 /// Emit an error about fatal conditions with this operation, reporting up to 230 /// any diagnostic handlers that may be listening. 231 InFlightDiagnostic Operation::emitError(const Twine &message) { 232 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 233 if (getContext()->shouldPrintOpOnDiagnostic()) { 234 diag.attachNote(getLoc()) 235 .append("see current operation: ") 236 .appendOp(*this, OpPrintingFlags().printGenericOpForm()); 237 } 238 return diag; 239 } 240 241 /// Emit a warning about this operation, reporting up to any diagnostic 242 /// handlers that may be listening. 243 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 244 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 245 if (getContext()->shouldPrintOpOnDiagnostic()) 246 diag.attachNote(getLoc()) << "see current operation: " << *this; 247 return diag; 248 } 249 250 /// Emit a remark about this operation, reporting up to any diagnostic 251 /// handlers that may be listening. 
InFlightDiagnostic Operation::emitRemark(const Twine &message) {
  InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message);
  // When the context requests it, attach the operation itself as a note.
  if (getContext()->shouldPrintOpOnDiagnostic())
    diag.attachNote(getLoc()) << "see current operation: " << *this;
  return diag;
}

//===----------------------------------------------------------------------===//
// Operation Ordering
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the static constexpr data members declared in the
// class (needed when they are ODR-used before C++17 inline variables).
constexpr unsigned Operation::kInvalidOrderIdx;
constexpr unsigned Operation::kOrderStride;

/// Given an operation 'other' that is within the same parent block, return
/// whether the current operation is before 'other' in the operation list
/// of the parent block.
/// Note: This function has an average complexity of O(1), but worst case may
/// take O(N) where N is the number of operations within the parent block.
bool Operation::isBeforeInBlock(Operation *other) {
  assert(block && "Operations without parent blocks have no order.");
  assert(other && other->block == block &&
         "Expected other operation to have the same parent block.");
  // If the order of the block is already invalid, directly recompute the
  // parent.
  if (!block->isOpOrderValid()) {
    block->recomputeOpOrder();
  } else {
    // Update the order of either operation if necessary.
    updateOrderIfNecessary();
    other->updateOrderIfNecessary();
  }

  // Both order indices are expected to be valid at this point, so comparing
  // them yields the relative position within the block.
  return orderIndex < other->orderIndex;
}

/// Update the order index of this operation if necessary, potentially
/// recomputing the order of the parent block.
///
/// Indices are handed out with gaps: an operation appended at the back gets
/// its predecessor's index plus `kOrderStride`, and an operation placed
/// between two ordered neighbors takes the midpoint of their indices. When no
/// unused index is available, the whole parent block is renumbered via
/// `recomputeOpOrder()`.
void Operation::updateOrderIfNecessary() {
  assert(block && "expected valid parent");

  // If the order is valid for this operation there is nothing to do.
  if (hasValidOrder())
    return;
  Operation *blockFront = &block->front();
  Operation *blockBack = &block->back();

  // This method is expected to only be invoked on blocks with more than one
  // operation.
  assert(blockFront != blockBack && "expected more than one operation");

  // If the operation is at the end of the block.
  if (this == blockBack) {
    Operation *prevNode = getPrevNode();
    if (!prevNode->hasValidOrder())
      return block->recomputeOpOrder();

    // Add the stride to the previous operation.
    orderIndex = prevNode->orderIndex + kOrderStride;
    return;
  }

  // If this is the first operation try to use the next operation to compute the
  // ordering.
  if (this == blockFront) {
    Operation *nextNode = getNextNode();
    if (!nextNode->hasValidOrder())
      return block->recomputeOpOrder();
    // The next operation already has index 0, so no smaller index is left to
    // give this operation.
    if (nextNode->orderIndex == 0)
      return block->recomputeOpOrder();

    // If the stride does not fit below the next index, take the midpoint of
    // [0, next). This is safe because next >= 1, so at least index 0 is free.
    if (nextNode->orderIndex <= kOrderStride)
      orderIndex = (nextNode->orderIndex / 2);
    else
      orderIndex = kOrderStride;
    return;
  }

  // Otherwise, this operation is between two others. Place this operation in
  // the middle of the previous and next if possible.
  Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
  if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
    return block->recomputeOpOrder();
  unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;

  // Adjacent indices leave no room in between, so renumber the whole block.
  if (prevOrder + 1 == nextOrder)
    return block->recomputeOpOrder();
  // Midpoint written as prev + (next - prev) / 2, which cannot overflow the
  // way (prev + next) / 2 could.
  orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}

//===----------------------------------------------------------------------===//
// ilist_traits for Operation
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the node <-> value pointer conversions for the
// intrusive-list node options computed for mlir::Operation. Each one forwards
// to the generic NodeAccess helper.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer n)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *n)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

/// Invoked when the list deletes an operation. Operations are malloc-allocated,
/// so they must be released through destroy() rather than `delete`.
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
  op->destroy();
}

/// Recover the Block that owns this operation list. The byte offset of the
/// list member within Block is computed from the member pointer returned by
/// Block::getSublistAccess, applied to a null Block pointer, and subtracted
/// from this list's address.
/// NOTE(review): forming a member reference through a null pointer is
/// technically undefined behavior in standard C++. This relies on compiler
/// support for the offsetof-style idiom.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset);
}

/// This is a trait method invoked when an operation is added to a block. We
/// keep the block pointer up to date.
388 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 389 assert(!op->getBlock() && "already in an operation block!"); 390 op->block = getContainingBlock(); 391 392 // Invalidate the order on the operation. 393 op->orderIndex = Operation::kInvalidOrderIdx; 394 } 395 396 /// This is a trait method invoked when an operation is removed from a block. 397 /// We keep the block pointer up to date. 398 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 399 assert(op->block && "not already in an operation block!"); 400 op->block = nullptr; 401 } 402 403 /// This is a trait method invoked when an operation is moved from one block 404 /// to another. We keep the block pointer up to date. 405 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 406 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 407 Block *curParent = getContainingBlock(); 408 409 // Invalidate the ordering of the parent block. 410 curParent->invalidateOpOrder(); 411 412 // If we are transferring operations within the same block, the block 413 // pointer doesn't need to be updated. 414 if (curParent == otherList.getContainingBlock()) 415 return; 416 417 // Update the 'block' member of each operation. 418 for (; first != last; ++first) 419 first->block = curParent; 420 } 421 422 /// Remove this operation (and its descendants) from its Block and delete 423 /// all of them. 424 void Operation::erase() { 425 if (auto *parent = getBlock()) 426 parent->getOperations().erase(this); 427 else 428 destroy(); 429 } 430 431 /// Remove the operation from its parent block, but don't delete it. 432 void Operation::remove() { 433 if (Block *parent = getBlock()) 434 parent->getOperations().remove(this); 435 } 436 437 /// Unlink this operation from its current block and insert it right before 438 /// `existingOp` which may be in the same or another block in the same 439 /// function. 
440 void Operation::moveBefore(Operation *existingOp) { 441 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 442 } 443 444 /// Unlink this operation from its current basic block and insert it right 445 /// before `iterator` in the specified basic block. 446 void Operation::moveBefore(Block *block, 447 llvm::iplist<Operation>::iterator iterator) { 448 block->getOperations().splice(iterator, getBlock()->getOperations(), 449 getIterator()); 450 } 451 452 /// Unlink this operation from its current block and insert it right after 453 /// `existingOp` which may be in the same or another block in the same function. 454 void Operation::moveAfter(Operation *existingOp) { 455 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 456 } 457 458 /// Unlink this operation from its current block and insert it right after 459 /// `iterator` in the specified block. 460 void Operation::moveAfter(Block *block, 461 llvm::iplist<Operation>::iterator iterator) { 462 assert(iterator != block->end() && "cannot move after end of block"); 463 moveBefore(block, std::next(iterator)); 464 } 465 466 /// This drops all operand uses from this operation, which is an essential 467 /// step in breaking cyclic dependences between references when they are to 468 /// be deleted. 469 void Operation::dropAllReferences() { 470 for (auto &op : getOpOperands()) 471 op.drop(); 472 473 for (auto ®ion : getRegions()) 474 region.dropAllReferences(); 475 476 for (auto &dest : getBlockOperands()) 477 dest.drop(); 478 } 479 480 /// This drops all uses of any values defined by this operation or its nested 481 /// regions, wherever they are located. 
482 void Operation::dropAllDefinedValueUses() { 483 dropAllUses(); 484 485 for (auto ®ion : getRegions()) 486 for (auto &block : region) 487 block.dropAllDefinedValueUses(); 488 } 489 490 void Operation::setSuccessor(Block *block, unsigned index) { 491 assert(index < getNumSuccessors()); 492 getBlockOperands()[index].set(block); 493 } 494 495 /// Attempt to fold this operation using the Op's registered foldHook. 496 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 497 SmallVectorImpl<OpFoldResult> &results) { 498 // If we have a registered operation definition matching this one, use it to 499 // try to constant fold the operation. 500 Optional<RegisteredOperationName> info = getRegisteredInfo(); 501 if (info && succeeded(info->foldHook(this, operands, results))) 502 return success(); 503 504 // Otherwise, fall back on the dialect hook to handle it. 505 Dialect *dialect = getDialect(); 506 if (!dialect) 507 return failure(); 508 509 auto *interface = dyn_cast<DialectFoldInterface>(dialect); 510 if (!interface) 511 return failure(); 512 513 return interface->fold(this, operands, results); 514 } 515 516 /// Emit an error with the op name prefixed, like "'dim' op " which is 517 /// convenient for verifiers. 
518 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 519 return emitError() << "'" << getName() << "' op " << message; 520 } 521 522 //===----------------------------------------------------------------------===// 523 // Operation Cloning 524 //===----------------------------------------------------------------------===// 525 526 Operation::CloneOptions::CloneOptions() 527 : cloneRegionsFlag(false), cloneOperandsFlag(false) {} 528 529 Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands) 530 : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {} 531 532 Operation::CloneOptions Operation::CloneOptions::all() { 533 return CloneOptions().cloneRegions().cloneOperands(); 534 } 535 536 Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) { 537 cloneRegionsFlag = enable; 538 return *this; 539 } 540 541 Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) { 542 cloneOperandsFlag = enable; 543 return *this; 544 } 545 546 /// Create a deep copy of this operation but keep the operation regions empty. 547 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 548 /// to contain the results. The `mapResults` flag specifies whether the results 549 /// of the cloned operation should be added to the map. 550 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 551 return clone(mapper, CloneOptions::all().cloneRegions(false)); 552 } 553 554 Operation *Operation::cloneWithoutRegions() { 555 BlockAndValueMapping mapper; 556 return cloneWithoutRegions(mapper); 557 } 558 559 /// Create a deep copy of this operation, remapping any operands that use 560 /// values outside of the operation using the map that is provided (leaving 561 /// them alone if no entry is present). Replaces references to cloned 562 /// sub-operations to the corresponding operation that is copied, and adds 563 /// those mappings to the map. 
564 Operation *Operation::clone(BlockAndValueMapping &mapper, 565 CloneOptions options) { 566 SmallVector<Value, 8> operands; 567 SmallVector<Block *, 2> successors; 568 569 // Remap the operands. 570 if (options.shouldCloneOperands()) { 571 operands.reserve(getNumOperands()); 572 for (auto opValue : getOperands()) 573 operands.push_back(mapper.lookupOrDefault(opValue)); 574 } 575 576 // Remap the successors. 577 successors.reserve(getNumSuccessors()); 578 for (Block *successor : getSuccessors()) 579 successors.push_back(mapper.lookupOrDefault(successor)); 580 581 // Create the new operation. 582 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 583 successors, getNumRegions()); 584 585 // Clone the regions. 586 if (options.shouldCloneRegions()) { 587 for (unsigned i = 0; i != numRegions; ++i) 588 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 589 } 590 591 // Remember the mapping of any results. 592 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 593 mapper.map(getResult(i), newOp->getResult(i)); 594 595 return newOp; 596 } 597 598 Operation *Operation::clone(CloneOptions options) { 599 BlockAndValueMapping mapper; 600 return clone(mapper, options); 601 } 602 603 //===----------------------------------------------------------------------===// 604 // OpState trait class. 605 //===----------------------------------------------------------------------===// 606 607 // The fallback for the parser is to try for a dialect operation parser. 608 // Otherwise, reject the custom assembly form. 609 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 610 if (auto parseFn = result.name.getDialect()->getParseOperationHook( 611 result.name.getStringRef())) 612 return (*parseFn)(parser, result); 613 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 614 } 615 616 // The fallback for the printer is to try for a dialect operation printer. 617 // Otherwise, it prints the generic form. 
618 void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { 619 if (auto printFn = op->getDialect()->getOperationPrinter(op)) { 620 printOpName(op, p, defaultDialect); 621 printFn(op, p); 622 } else { 623 p.printGenericOp(op); 624 } 625 } 626 627 /// Print an operation name, eliding the dialect prefix if necessary. 628 void OpState::printOpName(Operation *op, OpAsmPrinter &p, 629 StringRef defaultDialect) { 630 StringRef name = op->getName().getStringRef(); 631 if (name.startswith((defaultDialect + ".").str())) 632 name = name.drop_front(defaultDialect.size() + 1); 633 // TODO: remove this special case (and update test/IR/parser.mlir) 634 else if ((defaultDialect.empty() || defaultDialect == "builtin") && 635 name.startswith("func.")) 636 name = name.drop_front(5); 637 p.getStream() << name; 638 } 639 640 /// Emit an error about fatal conditions with this operation, reporting up to 641 /// any diagnostic handlers that may be listening. 642 InFlightDiagnostic OpState::emitError(const Twine &message) { 643 return getOperation()->emitError(message); 644 } 645 646 /// Emit an error with the op name prefixed, like "'dim' op " which is 647 /// convenient for verifiers. 648 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 649 return getOperation()->emitOpError(message); 650 } 651 652 /// Emit a warning about this operation, reporting up to any diagnostic 653 /// handlers that may be listening. 654 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 655 return getOperation()->emitWarning(message); 656 } 657 658 /// Emit a remark about this operation, reporting up to any diagnostic 659 /// handlers that may be listening. 
660 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 661 return getOperation()->emitRemark(message); 662 } 663 664 //===----------------------------------------------------------------------===// 665 // Op Trait implementations 666 //===----------------------------------------------------------------------===// 667 668 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 669 if (op->getNumOperands() == 1) { 670 auto *argumentOp = op->getOperand(0).getDefiningOp(); 671 if (argumentOp && op->getName() == argumentOp->getName()) { 672 // Replace the outer operation output with the inner operation. 673 return op->getOperand(0); 674 } 675 } else if (op->getOperand(0) == op->getOperand(1)) { 676 return op->getOperand(0); 677 } 678 679 return {}; 680 } 681 682 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 683 auto *argumentOp = op->getOperand(0).getDefiningOp(); 684 if (argumentOp && op->getName() == argumentOp->getName()) { 685 // Replace the outer involutions output with inner's input. 
686 return argumentOp->getOperand(0); 687 } 688 689 return {}; 690 } 691 692 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 693 if (op->getNumOperands() != 0) 694 return op->emitOpError() << "requires zero operands"; 695 return success(); 696 } 697 698 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 699 if (op->getNumOperands() != 1) 700 return op->emitOpError() << "requires a single operand"; 701 return success(); 702 } 703 704 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 705 unsigned numOperands) { 706 if (op->getNumOperands() != numOperands) { 707 return op->emitOpError() << "expected " << numOperands 708 << " operands, but found " << op->getNumOperands(); 709 } 710 return success(); 711 } 712 713 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 714 unsigned numOperands) { 715 if (op->getNumOperands() < numOperands) 716 return op->emitOpError() 717 << "expected " << numOperands << " or more operands, but found " 718 << op->getNumOperands(); 719 return success(); 720 } 721 722 /// If this is a vector type, or a tensor type, return the scalar element type 723 /// that it is built around, otherwise return the type unmodified. 724 static Type getTensorOrVectorElementType(Type type) { 725 if (auto vec = type.dyn_cast<VectorType>()) 726 return vec.getElementType(); 727 728 // Look through tensor<vector<...>> to find the underlying element type. 729 if (auto tensor = type.dyn_cast<TensorType>()) 730 return getTensorOrVectorElementType(tensor.getElementType()); 731 return type; 732 } 733 734 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 735 // FIXME: Add back check for no side effects on operation. 736 // Currently adding it would cause the shared library build 737 // to fail since there would be a dependency of IR on SideEffectInterfaces 738 // which is cyclical. 
739 return success(); 740 } 741 742 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 743 // FIXME: Add back check for no side effects on operation. 744 // Currently adding it would cause the shared library build 745 // to fail since there would be a dependency of IR on SideEffectInterfaces 746 // which is cyclical. 747 return success(); 748 } 749 750 LogicalResult 751 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 752 for (auto opType : op->getOperandTypes()) { 753 auto type = getTensorOrVectorElementType(opType); 754 if (!type.isSignlessIntOrIndex()) 755 return op->emitOpError() << "requires an integer or index type"; 756 } 757 return success(); 758 } 759 760 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 761 for (auto opType : op->getOperandTypes()) { 762 auto type = getTensorOrVectorElementType(opType); 763 if (!type.isa<FloatType>()) 764 return op->emitOpError("requires a float type"); 765 } 766 return success(); 767 } 768 769 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 770 // Zero or one operand always have the "same" type. 
771 unsigned nOperands = op->getNumOperands(); 772 if (nOperands < 2) 773 return success(); 774 775 auto type = op->getOperand(0).getType(); 776 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 777 if (opType != type) 778 return op->emitOpError() << "requires all operands to have the same type"; 779 return success(); 780 } 781 782 LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) { 783 if (op->getNumRegions() != 0) 784 return op->emitOpError() << "requires zero regions"; 785 return success(); 786 } 787 788 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 789 if (op->getNumRegions() != 1) 790 return op->emitOpError() << "requires one region"; 791 return success(); 792 } 793 794 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 795 unsigned numRegions) { 796 if (op->getNumRegions() != numRegions) 797 return op->emitOpError() << "expected " << numRegions << " regions"; 798 return success(); 799 } 800 801 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 802 unsigned numRegions) { 803 if (op->getNumRegions() < numRegions) 804 return op->emitOpError() << "expected " << numRegions << " or more regions"; 805 return success(); 806 } 807 808 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { 809 if (op->getNumResults() != 0) 810 return op->emitOpError() << "requires zero results"; 811 return success(); 812 } 813 814 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 815 if (op->getNumResults() != 1) 816 return op->emitOpError() << "requires one result"; 817 return success(); 818 } 819 820 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 821 unsigned numOperands) { 822 if (op->getNumResults() != numOperands) 823 return op->emitOpError() << "expected " << numOperands << " results"; 824 return success(); 825 } 826 827 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 828 unsigned numOperands) { 829 if (op->getNumResults() < numOperands) 830 return op->emitOpError() 
831 << "expected " << numOperands << " or more results"; 832 return success(); 833 } 834 835 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 836 if (failed(verifyAtLeastNOperands(op, 1))) 837 return failure(); 838 839 if (failed(verifyCompatibleShapes(op->getOperandTypes()))) 840 return op->emitOpError() << "requires the same shape for all operands"; 841 842 return success(); 843 } 844 845 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 846 if (failed(verifyAtLeastNOperands(op, 1)) || 847 failed(verifyAtLeastNResults(op, 1))) 848 return failure(); 849 850 SmallVector<Type, 8> types(op->getOperandTypes()); 851 types.append(llvm::to_vector<4>(op->getResultTypes())); 852 853 if (failed(verifyCompatibleShapes(types))) 854 return op->emitOpError() 855 << "requires the same shape for all operands and results"; 856 857 return success(); 858 } 859 860 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 861 if (failed(verifyAtLeastNOperands(op, 1))) 862 return failure(); 863 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 864 865 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 866 if (getElementTypeOrSelf(operand) != elementType) 867 return op->emitOpError("requires the same element type for all operands"); 868 } 869 870 return success(); 871 } 872 873 LogicalResult 874 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 875 if (failed(verifyAtLeastNOperands(op, 1)) || 876 failed(verifyAtLeastNResults(op, 1))) 877 return failure(); 878 879 auto elementType = getElementTypeOrSelf(op->getResult(0)); 880 881 // Verify result element type matches first result's element type. 
882 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 883 if (getElementTypeOrSelf(result) != elementType) 884 return op->emitOpError( 885 "requires the same element type for all operands and results"); 886 } 887 888 // Verify operand's element type matches first result's element type. 889 for (auto operand : op->getOperands()) { 890 if (getElementTypeOrSelf(operand) != elementType) 891 return op->emitOpError( 892 "requires the same element type for all operands and results"); 893 } 894 895 return success(); 896 } 897 898 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 899 if (failed(verifyAtLeastNOperands(op, 1)) || 900 failed(verifyAtLeastNResults(op, 1))) 901 return failure(); 902 903 auto type = op->getResult(0).getType(); 904 auto elementType = getElementTypeOrSelf(type); 905 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 906 if (getElementTypeOrSelf(resultType) != elementType || 907 failed(verifyCompatibleShape(resultType, type))) 908 return op->emitOpError() 909 << "requires the same type for all operands and results"; 910 } 911 for (auto opType : op->getOperandTypes()) { 912 if (getElementTypeOrSelf(opType) != elementType || 913 failed(verifyCompatibleShape(opType, type))) 914 return op->emitOpError() 915 << "requires the same type for all operands and results"; 916 } 917 return success(); 918 } 919 920 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 921 Block *block = op->getBlock(); 922 // Verify that the operation is at the end of the respective parent block. 923 if (!block || &block->back() != op) 924 return op->emitOpError("must be the last operation in the parent block"); 925 return success(); 926 } 927 928 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 929 auto *parent = op->getParentRegion(); 930 931 // Verify that the operands lines up with the BB arguments in the successor. 
932 for (Block *succ : op->getSuccessors()) 933 if (succ->getParent() != parent) 934 return op->emitError("reference to block defined in another region"); 935 return success(); 936 } 937 938 LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { 939 if (op->getNumSuccessors() != 0) { 940 return op->emitOpError("requires 0 successors but found ") 941 << op->getNumSuccessors(); 942 } 943 return success(); 944 } 945 946 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 947 if (op->getNumSuccessors() != 1) { 948 return op->emitOpError("requires 1 successor but found ") 949 << op->getNumSuccessors(); 950 } 951 return verifyTerminatorSuccessors(op); 952 } 953 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 954 unsigned numSuccessors) { 955 if (op->getNumSuccessors() != numSuccessors) { 956 return op->emitOpError("requires ") 957 << numSuccessors << " successors but found " 958 << op->getNumSuccessors(); 959 } 960 return verifyTerminatorSuccessors(op); 961 } 962 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 963 unsigned numSuccessors) { 964 if (op->getNumSuccessors() < numSuccessors) { 965 return op->emitOpError("requires at least ") 966 << numSuccessors << " successors but found " 967 << op->getNumSuccessors(); 968 } 969 return verifyTerminatorSuccessors(op); 970 } 971 972 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 973 for (auto resultType : op->getResultTypes()) { 974 auto elementType = getTensorOrVectorElementType(resultType); 975 bool isBoolType = elementType.isInteger(1); 976 if (!isBoolType) 977 return op->emitOpError() << "requires a bool result type"; 978 } 979 980 return success(); 981 } 982 983 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 984 for (auto resultType : op->getResultTypes()) 985 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 986 return op->emitOpError() << "requires a floating point type"; 987 988 return success(); 
989 } 990 991 LogicalResult 992 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 993 for (auto resultType : op->getResultTypes()) 994 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 995 return op->emitOpError() << "requires an integer or index type"; 996 return success(); 997 } 998 999 LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, 1000 StringRef attrName, 1001 StringRef valueGroupName, 1002 size_t expectedCount) { 1003 auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); 1004 if (!sizeAttr) 1005 return op->emitOpError("requires 1D i32 elements attribute '") 1006 << attrName << "'"; 1007 1008 auto sizeAttrType = sizeAttr.getType(); 1009 if (sizeAttrType.getRank() != 1 || 1010 !sizeAttrType.getElementType().isInteger(32)) 1011 return op->emitOpError("requires 1D i32 elements attribute '") 1012 << attrName << "'"; 1013 1014 if (llvm::any_of(sizeAttr.getValues<APInt>(), [](const APInt &element) { 1015 return !element.isNonNegative(); 1016 })) 1017 return op->emitOpError("'") 1018 << attrName << "' attribute cannot have negative elements"; 1019 1020 size_t totalCount = std::accumulate( 1021 sizeAttr.begin(), sizeAttr.end(), 0, 1022 [](unsigned all, const APInt &one) { return all + one.getZExtValue(); }); 1023 1024 if (totalCount != expectedCount) 1025 return op->emitOpError() 1026 << valueGroupName << " count (" << expectedCount 1027 << ") does not match with the total size (" << totalCount 1028 << ") specified in attribute '" << attrName << "'"; 1029 return success(); 1030 } 1031 1032 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1033 StringRef attrName) { 1034 return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); 1035 } 1036 1037 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1038 StringRef attrName) { 1039 return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); 1040 } 1041 1042 LogicalResult 
OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1043 for (Region ®ion : op->getRegions()) { 1044 if (region.empty()) 1045 continue; 1046 1047 if (region.getNumArguments() != 0) { 1048 if (op->getNumRegions() > 1) 1049 return op->emitOpError("region #") 1050 << region.getRegionNumber() << " should have no arguments"; 1051 return op->emitOpError("region should have no arguments"); 1052 } 1053 } 1054 return success(); 1055 } 1056 1057 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1058 auto isMappableType = [](Type type) { 1059 return type.isa<VectorType, TensorType>(); 1060 }; 1061 auto resultMappableTypes = llvm::to_vector<1>( 1062 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1063 auto operandMappableTypes = llvm::to_vector<2>( 1064 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1065 1066 // If the op only has scalar operand/result types, then we have nothing to 1067 // check. 1068 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1069 return success(); 1070 1071 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1072 return op->emitOpError("if a result is non-scalar, then at least one " 1073 "operand must be non-scalar"); 1074 1075 assert(!operandMappableTypes.empty()); 1076 1077 if (resultMappableTypes.empty()) 1078 return op->emitOpError("if an operand is non-scalar, then there must be at " 1079 "least one non-scalar result"); 1080 1081 if (resultMappableTypes.size() != op->getNumResults()) 1082 return op->emitOpError( 1083 "if an operand is non-scalar, then all results must be non-scalar"); 1084 1085 SmallVector<Type, 4> types = llvm::to_vector<2>( 1086 llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); 1087 TypeID expectedBaseTy = types.front().getTypeID(); 1088 if (!llvm::all_of(types, 1089 [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || 1090 failed(verifyCompatibleShapes(types))) { 1091 return op->emitOpError() << "all non-scalar 
operands/results must have the " 1092 "same shape and base type"; 1093 } 1094 1095 return success(); 1096 } 1097 1098 /// Check for any values used by operations regions attached to the 1099 /// specified "IsIsolatedFromAbove" operation defined outside of it. 1100 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { 1101 assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1102 "Intended to check IsolatedFromAbove ops"); 1103 1104 // List of regions to analyze. Each region is processed independently, with 1105 // respect to the common `limit` region, so we can look at them in any order. 1106 // Therefore, use a simple vector and push/pop back the current region. 1107 SmallVector<Region *, 8> pendingRegions; 1108 for (auto ®ion : isolatedOp->getRegions()) { 1109 pendingRegions.push_back(®ion); 1110 1111 // Traverse all operations in the region. 1112 while (!pendingRegions.empty()) { 1113 for (Operation &op : pendingRegions.pop_back_val()->getOps()) { 1114 for (Value operand : op.getOperands()) { 1115 // Check that any value that is used by an operation is defined in the 1116 // same region as either an operation result. 1117 auto *operandRegion = operand.getParentRegion(); 1118 if (!operandRegion) 1119 return op.emitError("operation's operand is unlinked"); 1120 if (!region.isAncestor(operandRegion)) { 1121 return op.emitOpError("using value defined outside the region") 1122 .attachNote(isolatedOp->getLoc()) 1123 << "required by region isolation constraints"; 1124 } 1125 } 1126 1127 // Schedule any regions in the operation for further checking. Don't 1128 // recurse into other IsolatedFromAbove ops, because they will check 1129 // themselves. 
1130 if (op.getNumRegions() && 1131 !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1132 for (Region &subRegion : op.getRegions()) 1133 pendingRegions.push_back(&subRegion); 1134 } 1135 } 1136 } 1137 } 1138 1139 return success(); 1140 } 1141 1142 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1143 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1144 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1145 } 1146 1147 //===----------------------------------------------------------------------===// 1148 // CastOpInterface 1149 //===----------------------------------------------------------------------===// 1150 1151 /// Attempt to fold the given cast operation. 1152 LogicalResult 1153 impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, 1154 SmallVectorImpl<OpFoldResult> &foldResults) { 1155 OperandRange operands = op->getOperands(); 1156 if (operands.empty()) 1157 return failure(); 1158 ResultRange results = op->getResults(); 1159 1160 // Check for the case where the input and output types match 1-1. 1161 if (operands.getTypes() == results.getTypes()) { 1162 foldResults.append(operands.begin(), operands.end()); 1163 return success(); 1164 } 1165 1166 return failure(); 1167 } 1168 1169 /// Attempt to verify the given cast operation. 
1170 LogicalResult impl::verifyCastInterfaceOp( 1171 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { 1172 auto resultTypes = op->getResultTypes(); 1173 if (llvm::empty(resultTypes)) 1174 return op->emitOpError() 1175 << "expected at least one result for cast operation"; 1176 1177 auto operandTypes = op->getOperandTypes(); 1178 if (!areCastCompatible(operandTypes, resultTypes)) { 1179 InFlightDiagnostic diag = op->emitOpError("operand type"); 1180 if (llvm::empty(operandTypes)) 1181 diag << "s []"; 1182 else if (llvm::size(operandTypes) == 1) 1183 diag << " " << *operandTypes.begin(); 1184 else 1185 diag << "s " << operandTypes; 1186 return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") 1187 << resultTypes << " are cast incompatible"; 1188 } 1189 1190 return success(); 1191 } 1192 1193 //===----------------------------------------------------------------------===// 1194 // Misc. utils 1195 //===----------------------------------------------------------------------===// 1196 1197 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1198 /// region's only block if it does not have a terminator already. If the region 1199 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1200 /// terminator operation to insert. 1201 void impl::ensureRegionTerminator( 1202 Region ®ion, OpBuilder &builder, Location loc, 1203 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1204 OpBuilder::InsertionGuard guard(builder); 1205 if (region.empty()) 1206 builder.createBlock(®ion); 1207 1208 Block &block = region.back(); 1209 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1210 return; 1211 1212 builder.setInsertionPointToEnd(&block); 1213 builder.insert(buildTerminatorOp(builder, loc)); 1214 } 1215 1216 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1217 /// function. 
1218 void impl::ensureRegionTerminator( 1219 Region ®ion, Builder &builder, Location loc, 1220 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1221 OpBuilder opBuilder(builder.getContext()); 1222 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1223 } 1224