1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Interfaces/FoldInterfaces.h" 17 #include "llvm/ADT/StringExtras.h" 18 #include <numeric> 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // Operation 24 //===----------------------------------------------------------------------===// 25 26 /// Create a new Operation with the specific fields. 27 Operation *Operation::create(Location location, OperationName name, 28 TypeRange resultTypes, ValueRange operands, 29 ArrayRef<NamedAttribute> attributes, 30 BlockRange successors, unsigned numRegions) { 31 return create(location, name, resultTypes, operands, 32 DictionaryAttr::get(location.getContext(), attributes), 33 successors, numRegions); 34 } 35 36 /// Create a new Operation from operation state. 37 Operation *Operation::create(const OperationState &state) { 38 return create(state.location, state.name, state.types, state.operands, 39 state.attributes.getDictionary(state.getContext()), 40 state.successors, state.regions); 41 } 42 43 /// Create a new Operation with the specific fields. 44 Operation *Operation::create(Location location, OperationName name, 45 TypeRange resultTypes, ValueRange operands, 46 DictionaryAttr attributes, BlockRange successors, 47 RegionRange regions) { 48 unsigned numRegions = regions.size(); 49 Operation *op = create(location, name, resultTypes, operands, attributes, 50 successors, numRegions); 51 for (unsigned i = 0; i < numRegions; ++i) 52 if (regions[i]) 53 op->getRegion(i).takeBody(*regions[i]); 54 return op; 55 } 56 57 /// Overload of create that takes an existing DictionaryAttr to avoid 58 /// unnecessarily uniquing a list of attributes. 59 Operation *Operation::create(Location location, OperationName name, 60 TypeRange resultTypes, ValueRange operands, 61 DictionaryAttr attributes, BlockRange successors, 62 unsigned numRegions) { 63 assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && 64 "unexpected null result type"); 65 66 // We only need to allocate additional memory for a subset of results. 67 unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); 68 unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); 69 unsigned numSuccessors = successors.size(); 70 unsigned numOperands = operands.size(); 71 unsigned numResults = resultTypes.size(); 72 73 // If the operation is known to have no operands, don't allocate an operand 74 // storage. 75 bool needsOperandStorage = 76 operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; 77 78 // Compute the byte size for the operation and the operand storage. This takes 79 // into account the size of the operation, its trailing objects, and its 80 // prefixed objects. 81 size_t byteSize = 82 totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>( 83 needsOperandStorage ? 1 : 0, numSuccessors, numRegions, numOperands); 84 size_t prefixByteSize = llvm::alignTo( 85 Operation::prefixAllocSize(numTrailingResults, numInlineResults), 86 alignof(Operation)); 87 char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); 88 void *rawMem = mallocMem + prefixByteSize; 89 90 // Create the new Operation. 91 Operation *op = 92 ::new (rawMem) Operation(location, name, numResults, numSuccessors, 93 numRegions, attributes, needsOperandStorage); 94 95 assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && 96 "unexpected successors in a non-terminator operation"); 97 98 // Initialize the results. 99 auto resultTypeIt = resultTypes.begin(); 100 for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) 101 new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); 102 for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { 103 new (op->getOutOfLineOpResult(i)) 104 detail::OutOfLineOpResult(*resultTypeIt, i); 105 } 106 107 // Initialize the regions. 108 for (unsigned i = 0; i != numRegions; ++i) 109 new (&op->getRegion(i)) Region(op); 110 111 // Initialize the operands. 112 if (needsOperandStorage) { 113 new (&op->getOperandStorage()) detail::OperandStorage( 114 op, op->getTrailingObjects<OpOperand>(), operands); 115 } 116 117 // Initialize the successors. 118 auto blockOperands = op->getBlockOperands(); 119 for (unsigned i = 0; i != numSuccessors; ++i) 120 new (&blockOperands[i]) BlockOperand(op, successors[i]); 121 122 return op; 123 } 124 125 Operation::Operation(Location location, OperationName name, unsigned numResults, 126 unsigned numSuccessors, unsigned numRegions, 127 DictionaryAttr attributes, bool hasOperandStorage) 128 : location(location), numResults(numResults), numSuccs(numSuccessors), 129 numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), 130 attrs(attributes) { 131 assert(attributes && "unexpected null attribute dictionary"); 132 #ifndef NDEBUG 133 if (!getDialect() && !getContext()->allowsUnregisteredDialects()) 134 llvm::report_fatal_error( 135 name.getStringRef() + 136 " created with unregistered dialect. If this is intended, please call " 137 "allowUnregisteredDialects() on the MLIRContext, or use " 138 "-allow-unregistered-dialect with the MLIR tool used."); 139 #endif 140 } 141 142 // Operations are deleted through the destroy() member because they are 143 // allocated via malloc. 144 Operation::~Operation() { 145 assert(block == nullptr && "operation destroyed but still in a block"); 146 #ifndef NDEBUG 147 if (!use_empty()) { 148 { 149 InFlightDiagnostic diag = 150 emitOpError("operation destroyed but still has uses"); 151 for (Operation *user : getUsers()) 152 diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; 153 } 154 llvm::report_fatal_error("operation destroyed but still has uses"); 155 } 156 #endif 157 // Explicitly run the destructors for the operands. 158 if (hasOperandStorage) 159 getOperandStorage().~OperandStorage(); 160 161 // Explicitly run the destructors for the successors. 162 for (auto &successor : getBlockOperands()) 163 successor.~BlockOperand(); 164 165 // Explicitly destroy the regions. 166 for (auto ®ion : getRegions()) 167 region.~Region(); 168 } 169 170 /// Destroy this operation or one of its subclasses. 171 void Operation::destroy() { 172 // Operations may have additional prefixed allocation, which needs to be 173 // accounted for here when computing the address to free. 174 char *rawMem = reinterpret_cast<char *>(this) - 175 llvm::alignTo(prefixAllocSize(), alignof(Operation)); 176 this->~Operation(); 177 free(rawMem); 178 } 179 180 /// Return true if this operation is a proper ancestor of the `other` 181 /// operation. 182 bool Operation::isProperAncestor(Operation *other) { 183 while ((other = other->getParentOp())) 184 if (this == other) 185 return true; 186 return false; 187 } 188 189 /// Replace any uses of 'from' with 'to' within this operation. 190 void Operation::replaceUsesOfWith(Value from, Value to) { 191 if (from == to) 192 return; 193 for (auto &operand : getOpOperands()) 194 if (operand.get() == from) 195 operand.set(to); 196 } 197 198 /// Replace the current operands of this operation with the ones provided in 199 /// 'operands'. 200 void Operation::setOperands(ValueRange operands) { 201 if (LLVM_LIKELY(hasOperandStorage)) 202 return getOperandStorage().setOperands(this, operands); 203 assert(operands.empty() && "setting operands without an operand storage"); 204 } 205 206 /// Replace the operands beginning at 'start' and ending at 'start' + 'length' 207 /// with the ones provided in 'operands'. 'operands' may be smaller or larger 208 /// than the range pointed to by 'start'+'length'. 209 void Operation::setOperands(unsigned start, unsigned length, 210 ValueRange operands) { 211 assert((start + length) <= getNumOperands() && 212 "invalid operand range specified"); 213 if (LLVM_LIKELY(hasOperandStorage)) 214 return getOperandStorage().setOperands(this, start, length, operands); 215 assert(operands.empty() && "setting operands without an operand storage"); 216 } 217 218 /// Insert the given operands into the operand list at the given 'index'. 219 void Operation::insertOperands(unsigned index, ValueRange operands) { 220 if (LLVM_LIKELY(hasOperandStorage)) 221 return setOperands(index, /*length=*/0, operands); 222 assert(operands.empty() && "inserting operands without an operand storage"); 223 } 224 225 //===----------------------------------------------------------------------===// 226 // Diagnostics 227 //===----------------------------------------------------------------------===// 228 229 /// Emit an error about fatal conditions with this operation, reporting up to 230 /// any diagnostic handlers that may be listening. 231 InFlightDiagnostic Operation::emitError(const Twine &message) { 232 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 233 if (getContext()->shouldPrintOpOnDiagnostic()) { 234 diag.attachNote(getLoc()) 235 .append("see current operation: ") 236 .appendOp(*this, OpPrintingFlags().printGenericOpForm()); 237 } 238 return diag; 239 } 240 241 /// Emit a warning about this operation, reporting up to any diagnostic 242 /// handlers that may be listening. 243 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 244 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 245 if (getContext()->shouldPrintOpOnDiagnostic()) 246 diag.attachNote(getLoc()) << "see current operation: " << *this; 247 return diag; 248 } 249 250 /// Emit a remark about this operation, reporting up to any diagnostic 251 /// handlers that may be listening. 252 InFlightDiagnostic Operation::emitRemark(const Twine &message) { 253 InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); 254 if (getContext()->shouldPrintOpOnDiagnostic()) 255 diag.attachNote(getLoc()) << "see current operation: " << *this; 256 return diag; 257 } 258 259 //===----------------------------------------------------------------------===// 260 // Operation Ordering 261 //===----------------------------------------------------------------------===// 262 263 constexpr unsigned Operation::kInvalidOrderIdx; 264 constexpr unsigned Operation::kOrderStride; 265 266 /// Given an operation 'other' that is within the same parent block, return 267 /// whether the current operation is before 'other' in the operation list 268 /// of the parent block. 269 /// Note: This function has an average complexity of O(1), but worst case may 270 /// take O(N) where N is the number of operations within the parent block. 271 bool Operation::isBeforeInBlock(Operation *other) { 272 assert(block && "Operations without parent blocks have no order."); 273 assert(other && other->block == block && 274 "Expected other operation to have the same parent block."); 275 // If the order of the block is already invalid, directly recompute the 276 // parent. 277 if (!block->isOpOrderValid()) { 278 block->recomputeOpOrder(); 279 } else { 280 // Update the order either operation if necessary. 281 updateOrderIfNecessary(); 282 other->updateOrderIfNecessary(); 283 } 284 285 return orderIndex < other->orderIndex; 286 } 287 288 /// Update the order index of this operation of this operation if necessary, 289 /// potentially recomputing the order of the parent block. 290 void Operation::updateOrderIfNecessary() { 291 assert(block && "expected valid parent"); 292 293 // If the order is valid for this operation there is nothing to do. 294 if (hasValidOrder()) 295 return; 296 Operation *blockFront = &block->front(); 297 Operation *blockBack = &block->back(); 298 299 // This method is expected to only be invoked on blocks with more than one 300 // operation. 301 assert(blockFront != blockBack && "expected more than one operation"); 302 303 // If the operation is at the end of the block. 304 if (this == blockBack) { 305 Operation *prevNode = getPrevNode(); 306 if (!prevNode->hasValidOrder()) 307 return block->recomputeOpOrder(); 308 309 // Add the stride to the previous operation. 310 orderIndex = prevNode->orderIndex + kOrderStride; 311 return; 312 } 313 314 // If this is the first operation try to use the next operation to compute the 315 // ordering. 316 if (this == blockFront) { 317 Operation *nextNode = getNextNode(); 318 if (!nextNode->hasValidOrder()) 319 return block->recomputeOpOrder(); 320 // There is no order to give this operation. 321 if (nextNode->orderIndex == 0) 322 return block->recomputeOpOrder(); 323 324 // If we can't use the stride, just take the middle value left. This is safe 325 // because we know there is at least one valid index to assign to. 326 if (nextNode->orderIndex <= kOrderStride) 327 orderIndex = (nextNode->orderIndex / 2); 328 else 329 orderIndex = kOrderStride; 330 return; 331 } 332 333 // Otherwise, this operation is between two others. Place this operation in 334 // the middle of the previous and next if possible. 335 Operation *prevNode = getPrevNode(), *nextNode = getNextNode(); 336 if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder()) 337 return block->recomputeOpOrder(); 338 unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex; 339 340 // Check to see if there is a valid order between the two. 341 if (prevOrder + 1 == nextOrder) 342 return block->recomputeOpOrder(); 343 orderIndex = prevOrder + ((nextOrder - prevOrder) / 2); 344 } 345 346 //===----------------------------------------------------------------------===// 347 // ilist_traits for Operation 348 //===----------------------------------------------------------------------===// 349 350 auto llvm::ilist_detail::SpecificNodeAccess< 351 typename llvm::ilist_detail::compute_node_options< 352 ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * { 353 return NodeAccess::getNodePtr<OptionsT>(n); 354 } 355 356 auto llvm::ilist_detail::SpecificNodeAccess< 357 typename llvm::ilist_detail::compute_node_options< 358 ::mlir::Operation>::type>::getNodePtr(const_pointer n) 359 -> const node_type * { 360 return NodeAccess::getNodePtr<OptionsT>(n); 361 } 362 363 auto llvm::ilist_detail::SpecificNodeAccess< 364 typename llvm::ilist_detail::compute_node_options< 365 ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer { 366 return NodeAccess::getValuePtr<OptionsT>(n); 367 } 368 369 auto llvm::ilist_detail::SpecificNodeAccess< 370 typename llvm::ilist_detail::compute_node_options< 371 ::mlir::Operation>::type>::getValuePtr(const node_type *n) 372 -> const_pointer { 373 return NodeAccess::getValuePtr<OptionsT>(n); 374 } 375 376 void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { 377 op->destroy(); 378 } 379 380 Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { 381 size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); 382 iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this)); 383 return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset); 384 } 385 386 /// This is a trait method invoked when an operation is added to a block. We 387 /// keep the block pointer up to date. 388 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 389 assert(!op->getBlock() && "already in an operation block!"); 390 op->block = getContainingBlock(); 391 392 // Invalidate the order on the operation. 393 op->orderIndex = Operation::kInvalidOrderIdx; 394 } 395 396 /// This is a trait method invoked when an operation is removed from a block. 397 /// We keep the block pointer up to date. 398 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 399 assert(op->block && "not already in an operation block!"); 400 op->block = nullptr; 401 } 402 403 /// This is a trait method invoked when an operation is moved from one block 404 /// to another. We keep the block pointer up to date. 405 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 406 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 407 Block *curParent = getContainingBlock(); 408 409 // Invalidate the ordering of the parent block. 410 curParent->invalidateOpOrder(); 411 412 // If we are transferring operations within the same block, the block 413 // pointer doesn't need to be updated. 414 if (curParent == otherList.getContainingBlock()) 415 return; 416 417 // Update the 'block' member of each operation. 418 for (; first != last; ++first) 419 first->block = curParent; 420 } 421 422 /// Remove this operation (and its descendants) from its Block and delete 423 /// all of them. 424 void Operation::erase() { 425 if (auto *parent = getBlock()) 426 parent->getOperations().erase(this); 427 else 428 destroy(); 429 } 430 431 /// Remove the operation from its parent block, but don't delete it. 432 void Operation::remove() { 433 if (Block *parent = getBlock()) 434 parent->getOperations().remove(this); 435 } 436 437 /// Unlink this operation from its current block and insert it right before 438 /// `existingOp` which may be in the same or another block in the same 439 /// function. 440 void Operation::moveBefore(Operation *existingOp) { 441 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 442 } 443 444 /// Unlink this operation from its current basic block and insert it right 445 /// before `iterator` in the specified basic block. 446 void Operation::moveBefore(Block *block, 447 llvm::iplist<Operation>::iterator iterator) { 448 block->getOperations().splice(iterator, getBlock()->getOperations(), 449 getIterator()); 450 } 451 452 /// Unlink this operation from its current block and insert it right after 453 /// `existingOp` which may be in the same or another block in the same function. 454 void Operation::moveAfter(Operation *existingOp) { 455 moveAfter(existingOp->getBlock(), existingOp->getIterator()); 456 } 457 458 /// Unlink this operation from its current block and insert it right after 459 /// `iterator` in the specified block. 460 void Operation::moveAfter(Block *block, 461 llvm::iplist<Operation>::iterator iterator) { 462 assert(iterator != block->end() && "cannot move after end of block"); 463 moveBefore(block, std::next(iterator)); 464 } 465 466 /// This drops all operand uses from this operation, which is an essential 467 /// step in breaking cyclic dependences between references when they are to 468 /// be deleted. 469 void Operation::dropAllReferences() { 470 for (auto &op : getOpOperands()) 471 op.drop(); 472 473 for (auto ®ion : getRegions()) 474 region.dropAllReferences(); 475 476 for (auto &dest : getBlockOperands()) 477 dest.drop(); 478 } 479 480 /// This drops all uses of any values defined by this operation or its nested 481 /// regions, wherever they are located. 482 void Operation::dropAllDefinedValueUses() { 483 dropAllUses(); 484 485 for (auto ®ion : getRegions()) 486 for (auto &block : region) 487 block.dropAllDefinedValueUses(); 488 } 489 490 void Operation::setSuccessor(Block *block, unsigned index) { 491 assert(index < getNumSuccessors()); 492 getBlockOperands()[index].set(block); 493 } 494 495 /// Attempt to fold this operation using the Op's registered foldHook. 496 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 497 SmallVectorImpl<OpFoldResult> &results) { 498 // If we have a registered operation definition matching this one, use it to 499 // try to constant fold the operation. 500 Optional<RegisteredOperationName> info = getRegisteredInfo(); 501 if (info && succeeded(info->foldHook(this, operands, results))) 502 return success(); 503 504 // Otherwise, fall back on the dialect hook to handle it. 505 Dialect *dialect = getDialect(); 506 if (!dialect) 507 return failure(); 508 509 auto *interface = dyn_cast<DialectFoldInterface>(dialect); 510 if (!interface) 511 return failure(); 512 513 return interface->fold(this, operands, results); 514 } 515 516 /// Emit an error with the op name prefixed, like "'dim' op " which is 517 /// convenient for verifiers. 518 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 519 return emitError() << "'" << getName() << "' op " << message; 520 } 521 522 //===----------------------------------------------------------------------===// 523 // Operation Cloning 524 //===----------------------------------------------------------------------===// 525 526 /// Create a deep copy of this operation but keep the operation regions empty. 527 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 528 /// to contain the results. 529 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 530 SmallVector<Value, 8> operands; 531 SmallVector<Block *, 2> successors; 532 533 // Remap the operands. 534 operands.reserve(getNumOperands()); 535 for (auto opValue : getOperands()) 536 operands.push_back(mapper.lookupOrDefault(opValue)); 537 538 // Remap the successors. 539 successors.reserve(getNumSuccessors()); 540 for (Block *successor : getSuccessors()) 541 successors.push_back(mapper.lookupOrDefault(successor)); 542 543 // Create the new operation. 544 auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, 545 successors, getNumRegions()); 546 547 // Remember the mapping of any results. 548 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 549 mapper.map(getResult(i), newOp->getResult(i)); 550 551 return newOp; 552 } 553 554 Operation *Operation::cloneWithoutRegions() { 555 BlockAndValueMapping mapper; 556 return cloneWithoutRegions(mapper); 557 } 558 559 /// Create a deep copy of this operation, remapping any operands that use 560 /// values outside of the operation using the map that is provided (leaving 561 /// them alone if no entry is present). Replaces references to cloned 562 /// sub-operations to the corresponding operation that is copied, and adds 563 /// those mappings to the map. 564 Operation *Operation::clone(BlockAndValueMapping &mapper) { 565 auto *newOp = cloneWithoutRegions(mapper); 566 567 // Clone the regions. 568 for (unsigned i = 0; i != numRegions; ++i) 569 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 570 571 return newOp; 572 } 573 574 Operation *Operation::clone() { 575 BlockAndValueMapping mapper; 576 return clone(mapper); 577 } 578 579 //===----------------------------------------------------------------------===// 580 // OpState trait class. 581 //===----------------------------------------------------------------------===// 582 583 // The fallback for the parser is to try for a dialect operation parser. 584 // Otherwise, reject the custom assembly form. 585 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 586 if (auto parseFn = result.name.getDialect()->getParseOperationHook( 587 result.name.getStringRef())) 588 return (*parseFn)(parser, result); 589 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 590 } 591 592 // The fallback for the printer is to try for a dialect operation printer. 593 // Otherwise, it prints the generic form. 594 void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { 595 if (auto printFn = op->getDialect()->getOperationPrinter(op)) { 596 printOpName(op, p, defaultDialect); 597 printFn(op, p); 598 } else { 599 p.printGenericOp(op); 600 } 601 } 602 603 /// Print an operation name, eliding the dialect prefix if necessary. 604 void OpState::printOpName(Operation *op, OpAsmPrinter &p, 605 StringRef defaultDialect) { 606 StringRef name = op->getName().getStringRef(); 607 if (name.startswith((defaultDialect + ".").str())) 608 name = name.drop_front(defaultDialect.size() + 1); 609 // TODO: remove this special case (and update test/IR/parser.mlir) 610 else if ((defaultDialect.empty() || defaultDialect == "builtin") && 611 name.startswith("func.")) 612 name = name.drop_front(5); 613 p.getStream() << name; 614 } 615 616 /// Emit an error about fatal conditions with this operation, reporting up to 617 /// any diagnostic handlers that may be listening. 618 InFlightDiagnostic OpState::emitError(const Twine &message) { 619 return getOperation()->emitError(message); 620 } 621 622 /// Emit an error with the op name prefixed, like "'dim' op " which is 623 /// convenient for verifiers. 624 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 625 return getOperation()->emitOpError(message); 626 } 627 628 /// Emit a warning about this operation, reporting up to any diagnostic 629 /// handlers that may be listening. 630 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 631 return getOperation()->emitWarning(message); 632 } 633 634 /// Emit a remark about this operation, reporting up to any diagnostic 635 /// handlers that may be listening. 636 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 637 return getOperation()->emitRemark(message); 638 } 639 640 //===----------------------------------------------------------------------===// 641 // Op Trait implementations 642 //===----------------------------------------------------------------------===// 643 644 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { 645 if (op->getNumOperands() == 1) { 646 auto *argumentOp = op->getOperand(0).getDefiningOp(); 647 if (argumentOp && op->getName() == argumentOp->getName()) { 648 // Replace the outer operation output with the inner operation. 649 return op->getOperand(0); 650 } 651 } else if (op->getOperand(0) == op->getOperand(1)) { 652 return op->getOperand(0); 653 } 654 655 return {}; 656 } 657 658 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { 659 auto *argumentOp = op->getOperand(0).getDefiningOp(); 660 if (argumentOp && op->getName() == argumentOp->getName()) { 661 // Replace the outer involutions output with inner's input. 662 return argumentOp->getOperand(0); 663 } 664 665 return {}; 666 } 667 668 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 669 if (op->getNumOperands() != 0) 670 return op->emitOpError() << "requires zero operands"; 671 return success(); 672 } 673 674 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 675 if (op->getNumOperands() != 1) 676 return op->emitOpError() << "requires a single operand"; 677 return success(); 678 } 679 680 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 681 unsigned numOperands) { 682 if (op->getNumOperands() != numOperands) { 683 return op->emitOpError() << "expected " << numOperands 684 << " operands, but found " << op->getNumOperands(); 685 } 686 return success(); 687 } 688 689 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 690 unsigned numOperands) { 691 if (op->getNumOperands() < numOperands) 692 return op->emitOpError() 693 << "expected " << numOperands << " or more operands, but found " 694 << op->getNumOperands(); 695 return success(); 696 } 697 698 /// If this is a vector type, or a tensor type, return the scalar element type 699 /// that it is built around, otherwise return the type unmodified. 700 static Type getTensorOrVectorElementType(Type type) { 701 if (auto vec = type.dyn_cast<VectorType>()) 702 return vec.getElementType(); 703 704 // Look through tensor<vector<...>> to find the underlying element type. 705 if (auto tensor = type.dyn_cast<TensorType>()) 706 return getTensorOrVectorElementType(tensor.getElementType()); 707 return type; 708 } 709 710 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { 711 // FIXME: Add back check for no side effects on operation. 712 // Currently adding it would cause the shared library build 713 // to fail since there would be a dependency of IR on SideEffectInterfaces 714 // which is cyclical. 715 return success(); 716 } 717 718 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { 719 // FIXME: Add back check for no side effects on operation. 720 // Currently adding it would cause the shared library build 721 // to fail since there would be a dependency of IR on SideEffectInterfaces 722 // which is cyclical. 723 return success(); 724 } 725 726 LogicalResult 727 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 728 for (auto opType : op->getOperandTypes()) { 729 auto type = getTensorOrVectorElementType(opType); 730 if (!type.isSignlessIntOrIndex()) 731 return op->emitOpError() << "requires an integer or index type"; 732 } 733 return success(); 734 } 735 736 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 737 for (auto opType : op->getOperandTypes()) { 738 auto type = getTensorOrVectorElementType(opType); 739 if (!type.isa<FloatType>()) 740 return op->emitOpError("requires a float type"); 741 } 742 return success(); 743 } 744 745 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 746 // Zero or one operand always have the "same" type. 747 unsigned nOperands = op->getNumOperands(); 748 if (nOperands < 2) 749 return success(); 750 751 auto type = op->getOperand(0).getType(); 752 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 753 if (opType != type) 754 return op->emitOpError() << "requires all operands to have the same type"; 755 return success(); 756 } 757 758 LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) { 759 if (op->getNumRegions() != 0) 760 return op->emitOpError() << "requires zero regions"; 761 return success(); 762 } 763 764 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 765 if (op->getNumRegions() != 1) 766 return op->emitOpError() << "requires one region"; 767 return success(); 768 } 769 770 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 771 unsigned numRegions) { 772 if (op->getNumRegions() != numRegions) 773 return op->emitOpError() << "expected " << numRegions << " regions"; 774 return success(); 775 } 776 777 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 778 unsigned numRegions) { 779 if (op->getNumRegions() < numRegions) 780 return op->emitOpError() << "expected " << numRegions << " or more regions"; 781 return success(); 782 } 783 784 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { 785 if (op->getNumResults() != 0) 786 return op->emitOpError() << "requires zero results"; 787 return success(); 788 } 789 790 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 791 if (op->getNumResults() != 1) 792 return op->emitOpError() << "requires one result"; 793 return success(); 794 } 795 796 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 797 unsigned numOperands) { 798 if (op->getNumResults() != numOperands) 799 return op->emitOpError() << "expected " << numOperands << " results"; 800 return success(); 801 } 802 803 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 804 unsigned numOperands) { 805 if (op->getNumResults() < numOperands) 806 return op->emitOpError() 807 << "expected " << numOperands << " or more results"; 808 return success(); 809 } 810 811 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 812 if (failed(verifyAtLeastNOperands(op, 1))) 813 return failure(); 814 815 if (failed(verifyCompatibleShapes(op->getOperandTypes()))) 816 return op->emitOpError() << "requires the same shape for all operands"; 817 818 return success(); 819 } 820 821 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 822 if (failed(verifyAtLeastNOperands(op, 1)) || 823 failed(verifyAtLeastNResults(op, 1))) 824 return failure(); 825 826 SmallVector<Type, 8> types(op->getOperandTypes()); 827 types.append(llvm::to_vector<4>(op->getResultTypes())); 828 829 if (failed(verifyCompatibleShapes(types))) 830 return op->emitOpError() 831 << "requires the same shape for all operands and results"; 832 833 return success(); 834 } 835 836 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 837 if (failed(verifyAtLeastNOperands(op, 1))) 838 return failure(); 839 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 840 841 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 842 if (getElementTypeOrSelf(operand) != elementType) 843 return op->emitOpError("requires the same element type for all operands"); 844 } 845 846 return success(); 847 } 848 849 LogicalResult 850 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 851 if (failed(verifyAtLeastNOperands(op, 1)) || 852 failed(verifyAtLeastNResults(op, 1))) 853 return failure(); 854 855 auto elementType = getElementTypeOrSelf(op->getResult(0)); 856 857 // Verify result element type matches first result's element type. 858 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 859 if (getElementTypeOrSelf(result) != elementType) 860 return op->emitOpError( 861 "requires the same element type for all operands and results"); 862 } 863 864 // Verify operand's element type matches first result's element type. 865 for (auto operand : op->getOperands()) { 866 if (getElementTypeOrSelf(operand) != elementType) 867 return op->emitOpError( 868 "requires the same element type for all operands and results"); 869 } 870 871 return success(); 872 } 873 874 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 875 if (failed(verifyAtLeastNOperands(op, 1)) || 876 failed(verifyAtLeastNResults(op, 1))) 877 return failure(); 878 879 auto type = op->getResult(0).getType(); 880 auto elementType = getElementTypeOrSelf(type); 881 for (auto resultType : llvm::drop_begin(op->getResultTypes())) { 882 if (getElementTypeOrSelf(resultType) != elementType || 883 failed(verifyCompatibleShape(resultType, type))) 884 return op->emitOpError() 885 << "requires the same type for all operands and results"; 886 } 887 for (auto opType : op->getOperandTypes()) { 888 if (getElementTypeOrSelf(opType) != elementType || 889 failed(verifyCompatibleShape(opType, type))) 890 return op->emitOpError() 891 << "requires the same type for all operands and results"; 892 } 893 return success(); 894 } 895 896 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 897 Block *block = op->getBlock(); 898 // Verify that the operation is at the end of the respective parent block. 899 if (!block || &block->back() != op) 900 return op->emitOpError("must be the last operation in the parent block"); 901 return success(); 902 } 903 904 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 905 auto *parent = op->getParentRegion(); 906 907 // Verify that the operands lines up with the BB arguments in the successor. 908 for (Block *succ : op->getSuccessors()) 909 if (succ->getParent() != parent) 910 return op->emitError("reference to block defined in another region"); 911 return success(); 912 } 913 914 LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { 915 if (op->getNumSuccessors() != 0) { 916 return op->emitOpError("requires 0 successors but found ") 917 << op->getNumSuccessors(); 918 } 919 return success(); 920 } 921 922 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 923 if (op->getNumSuccessors() != 1) { 924 return op->emitOpError("requires 1 successor but found ") 925 << op->getNumSuccessors(); 926 } 927 return verifyTerminatorSuccessors(op); 928 } 929 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 930 unsigned numSuccessors) { 931 if (op->getNumSuccessors() != numSuccessors) { 932 return op->emitOpError("requires ") 933 << numSuccessors << " successors but found " 934 << op->getNumSuccessors(); 935 } 936 return verifyTerminatorSuccessors(op); 937 } 938 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 939 unsigned numSuccessors) { 940 if (op->getNumSuccessors() < numSuccessors) { 941 return op->emitOpError("requires at least ") 942 << numSuccessors << " successors but found " 943 << op->getNumSuccessors(); 944 } 945 return verifyTerminatorSuccessors(op); 946 } 947 948 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 949 for (auto resultType : op->getResultTypes()) { 950 auto elementType = getTensorOrVectorElementType(resultType); 951 bool isBoolType = elementType.isInteger(1); 952 if (!isBoolType) 953 return op->emitOpError() << "requires a bool result type"; 954 } 955 956 return success(); 957 } 958 959 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 960 for (auto resultType : op->getResultTypes()) 961 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 962 return op->emitOpError() << "requires a floating point type"; 963 964 return success(); 965 } 966 967 LogicalResult 968 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 969 for (auto resultType : op->getResultTypes()) 970 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 971 return op->emitOpError() << "requires an integer or index type"; 972 return success(); 973 } 974 975 LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, 976 StringRef attrName, 977 StringRef valueGroupName, 978 size_t expectedCount) { 979 auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); 980 if (!sizeAttr) 981 return op->emitOpError("requires 1D i32 elements attribute '") 982 << attrName << "'"; 983 984 auto sizeAttrType = sizeAttr.getType(); 985 if (sizeAttrType.getRank() != 1 || 986 !sizeAttrType.getElementType().isInteger(32)) 987 return op->emitOpError("requires 1D i32 elements attribute '") 988 << attrName << "'"; 989 990 if (llvm::any_of(sizeAttr.getValues<APInt>(), [](const APInt &element) { 991 return !element.isNonNegative(); 992 })) 993 return op->emitOpError("'") 994 << attrName << "' attribute cannot have negative elements"; 995 996 size_t totalCount = std::accumulate( 997 sizeAttr.begin(), sizeAttr.end(), 0, 998 [](unsigned all, const APInt &one) { return all + one.getZExtValue(); }); 999 1000 if (totalCount != expectedCount) 1001 return op->emitOpError() 1002 << valueGroupName << " count (" << expectedCount 1003 << ") does not match with the total size (" << totalCount 1004 << ") specified in attribute '" << attrName << "'"; 1005 return success(); 1006 } 1007 1008 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 1009 StringRef attrName) { 1010 return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); 1011 } 1012 1013 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 1014 StringRef attrName) { 1015 return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); 1016 } 1017 1018 LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { 1019 for (Region ®ion : op->getRegions()) { 1020 if (region.empty()) 1021 continue; 1022 1023 if (region.getNumArguments() != 0) { 1024 if (op->getNumRegions() > 1) 1025 return op->emitOpError("region #") 1026 << region.getRegionNumber() << " should have no arguments"; 1027 return op->emitOpError("region should have no arguments"); 1028 } 1029 } 1030 return success(); 1031 } 1032 1033 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { 1034 auto isMappableType = [](Type type) { 1035 return type.isa<VectorType, TensorType>(); 1036 }; 1037 auto resultMappableTypes = llvm::to_vector<1>( 1038 llvm::make_filter_range(op->getResultTypes(), isMappableType)); 1039 auto operandMappableTypes = llvm::to_vector<2>( 1040 llvm::make_filter_range(op->getOperandTypes(), isMappableType)); 1041 1042 // If the op only has scalar operand/result types, then we have nothing to 1043 // check. 1044 if (resultMappableTypes.empty() && operandMappableTypes.empty()) 1045 return success(); 1046 1047 if (!resultMappableTypes.empty() && operandMappableTypes.empty()) 1048 return op->emitOpError("if a result is non-scalar, then at least one " 1049 "operand must be non-scalar"); 1050 1051 assert(!operandMappableTypes.empty()); 1052 1053 if (resultMappableTypes.empty()) 1054 return op->emitOpError("if an operand is non-scalar, then there must be at " 1055 "least one non-scalar result"); 1056 1057 if (resultMappableTypes.size() != op->getNumResults()) 1058 return op->emitOpError( 1059 "if an operand is non-scalar, then all results must be non-scalar"); 1060 1061 SmallVector<Type, 4> types = llvm::to_vector<2>( 1062 llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); 1063 TypeID expectedBaseTy = types.front().getTypeID(); 1064 if (!llvm::all_of(types, 1065 [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || 1066 failed(verifyCompatibleShapes(types))) { 1067 return op->emitOpError() << "all non-scalar operands/results must have the " 1068 "same shape and base type"; 1069 } 1070 1071 return success(); 1072 } 1073 1074 /// Check for any values used by operations regions attached to the 1075 /// specified "IsIsolatedFromAbove" operation defined outside of it. 1076 LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { 1077 assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1078 "Intended to check IsolatedFromAbove ops"); 1079 1080 // List of regions to analyze. Each region is processed independently, with 1081 // respect to the common `limit` region, so we can look at them in any order. 1082 // Therefore, use a simple vector and push/pop back the current region. 1083 SmallVector<Region *, 8> pendingRegions; 1084 for (auto ®ion : isolatedOp->getRegions()) { 1085 pendingRegions.push_back(®ion); 1086 1087 // Traverse all operations in the region. 1088 while (!pendingRegions.empty()) { 1089 for (Operation &op : pendingRegions.pop_back_val()->getOps()) { 1090 for (Value operand : op.getOperands()) { 1091 // operand should be non-null here if the IR is well-formed. But 1092 // we don't assert here as this function is called from the verifier 1093 // and so could be called on invalid IR. 1094 if (!operand) 1095 return op.emitOpError("operation's operand is null"); 1096 1097 // Check that any value that is used by an operation is defined in the 1098 // same region as either an operation result. 1099 auto *operandRegion = operand.getParentRegion(); 1100 if (!region.isAncestor(operandRegion)) { 1101 return op.emitOpError("using value defined outside the region") 1102 .attachNote(isolatedOp->getLoc()) 1103 << "required by region isolation constraints"; 1104 } 1105 } 1106 1107 // Schedule any regions in the operation for further checking. Don't 1108 // recurse into other IsolatedFromAbove ops, because they will check 1109 // themselves. 1110 if (op.getNumRegions() && 1111 !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1112 for (Region &subRegion : op.getRegions()) 1113 pendingRegions.push_back(&subRegion); 1114 } 1115 } 1116 } 1117 } 1118 1119 return success(); 1120 } 1121 1122 bool OpTrait::hasElementwiseMappableTraits(Operation *op) { 1123 return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && 1124 op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); 1125 } 1126 1127 //===----------------------------------------------------------------------===// 1128 // CastOpInterface 1129 //===----------------------------------------------------------------------===// 1130 1131 /// Attempt to fold the given cast operation. 1132 LogicalResult 1133 impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, 1134 SmallVectorImpl<OpFoldResult> &foldResults) { 1135 OperandRange operands = op->getOperands(); 1136 if (operands.empty()) 1137 return failure(); 1138 ResultRange results = op->getResults(); 1139 1140 // Check for the case where the input and output types match 1-1. 1141 if (operands.getTypes() == results.getTypes()) { 1142 foldResults.append(operands.begin(), operands.end()); 1143 return success(); 1144 } 1145 1146 return failure(); 1147 } 1148 1149 /// Attempt to verify the given cast operation. 1150 LogicalResult impl::verifyCastInterfaceOp( 1151 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { 1152 auto resultTypes = op->getResultTypes(); 1153 if (llvm::empty(resultTypes)) 1154 return op->emitOpError() 1155 << "expected at least one result for cast operation"; 1156 1157 auto operandTypes = op->getOperandTypes(); 1158 if (!areCastCompatible(operandTypes, resultTypes)) { 1159 InFlightDiagnostic diag = op->emitOpError("operand type"); 1160 if (llvm::empty(operandTypes)) 1161 diag << "s []"; 1162 else if (llvm::size(operandTypes) == 1) 1163 diag << " " << *operandTypes.begin(); 1164 else 1165 diag << "s " << operandTypes; 1166 return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") 1167 << resultTypes << " are cast incompatible"; 1168 } 1169 1170 return success(); 1171 } 1172 1173 //===----------------------------------------------------------------------===// 1174 // Misc. utils 1175 //===----------------------------------------------------------------------===// 1176 1177 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1178 /// region's only block if it does not have a terminator already. If the region 1179 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1180 /// terminator operation to insert. 1181 void impl::ensureRegionTerminator( 1182 Region ®ion, OpBuilder &builder, Location loc, 1183 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1184 OpBuilder::InsertionGuard guard(builder); 1185 if (region.empty()) 1186 builder.createBlock(®ion); 1187 1188 Block &block = region.back(); 1189 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 1190 return; 1191 1192 builder.setInsertionPointToEnd(&block); 1193 builder.insert(buildTerminatorOp(builder, loc)); 1194 } 1195 1196 /// Create a simple OpBuilder and forward to the OpBuilder version of this 1197 /// function. 1198 void impl::ensureRegionTerminator( 1199 Region ®ion, Builder &builder, Location loc, 1200 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { 1201 OpBuilder opBuilder(builder.getContext()); 1202 ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); 1203 } 1204