1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/Dialect.h" 12 #include "mlir/IR/OpImplementation.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/IR/StandardTypes.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include <numeric> 17 18 using namespace mlir; 19 20 OpAsmParser::~OpAsmParser() {} 21 22 //===----------------------------------------------------------------------===// 23 // OperationName 24 //===----------------------------------------------------------------------===// 25 26 /// Form the OperationName for an op with the specified string. This either is 27 /// a reference to an AbstractOperation if one is known, or a uniqued Identifier 28 /// if not. 29 OperationName::OperationName(StringRef name, MLIRContext *context) { 30 if (auto *op = AbstractOperation::lookup(name, context)) 31 representation = op; 32 else 33 representation = Identifier::get(name, context); 34 } 35 36 /// Return the name of the dialect this operation is registered to. 37 StringRef OperationName::getDialect() const { 38 return getStringRef().split('.').first; 39 } 40 41 /// Return the name of this operation. This always succeeds. 
42 StringRef OperationName::getStringRef() const { 43 if (auto *op = representation.dyn_cast<const AbstractOperation *>()) 44 return op->name; 45 return representation.get<Identifier>().strref(); 46 } 47 48 const AbstractOperation *OperationName::getAbstractOperation() const { 49 return representation.dyn_cast<const AbstractOperation *>(); 50 } 51 52 OperationName OperationName::getFromOpaquePointer(void *pointer) { 53 return OperationName(RepresentationUnion::getFromOpaqueValue(pointer)); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // Operation 58 //===----------------------------------------------------------------------===// 59 60 /// Create a new Operation with the specific fields. 61 Operation *Operation::create(Location location, OperationName name, 62 ArrayRef<Type> resultTypes, 63 ArrayRef<Value> operands, 64 ArrayRef<NamedAttribute> attributes, 65 ArrayRef<Block *> successors, unsigned numRegions, 66 bool resizableOperandList) { 67 return create(location, name, resultTypes, operands, 68 NamedAttributeList(attributes), successors, numRegions, 69 resizableOperandList); 70 } 71 72 /// Create a new Operation from operation state. 73 Operation *Operation::create(const OperationState &state) { 74 return Operation::create(state.location, state.name, state.types, 75 state.operands, NamedAttributeList(state.attributes), 76 state.successors, state.regions, 77 state.resizableOperandList); 78 } 79 80 /// Create a new Operation with the specific fields. 
Operation *Operation::create(Location location, OperationName name,
                             ArrayRef<Type> resultTypes,
                             ArrayRef<Value> operands,
                             NamedAttributeList attributes,
                             ArrayRef<Block *> successors, RegionRange regions,
                             bool resizableOperandList) {
  // Allocate the operation with one region slot per entry in 'regions'.
  unsigned numRegions = regions.size();
  Operation *op = create(location, name, resultTypes, operands, attributes,
                         successors, numRegions, resizableOperandList);
  // Move the body of every non-null provided region into the matching region
  // of the new operation; null entries leave that region as constructed.
  for (unsigned i = 0; i < numRegions; ++i)
    if (regions[i])
      op->getRegion(i).takeBody(*regions[i]);
  return op;
}

/// Overload of create that takes an existing NamedAttributeList to avoid
/// unnecessarily uniquing a list of attributes.
///
/// The Operation and all of its trailing objects (trailing results, block
/// operands, regions, operand storage and the storage's extra allocation) live
/// in one malloc'ed buffer; destroy() releases it with free().
Operation *Operation::create(Location location, OperationName name,
                             ArrayRef<Type> resultTypes,
                             ArrayRef<Value> operands,
                             NamedAttributeList attributes,
                             ArrayRef<Block *> successors, unsigned numRegions,
                             bool resizableOperandList) {
  // We only need to allocate additional memory for a subset of results.
  unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size());
  unsigned numSuccessors = successors.size();
  unsigned numOperands = operands.size();

  // Compute the byte size for the operation and the operand storage.
  auto byteSize = totalSizeToAlloc<detail::TrailingOpResult, BlockOperand,
                                   Region, detail::OperandStorage>(
      numTrailingResults, numSuccessors, numRegions,
      /*detail::OperandStorage*/ 1);
  // Append the operand storage's additional allocation, rounded up to the
  // alignment of Operation.
  byteSize += llvm::alignTo(detail::OperandStorage::additionalAllocSize(
                                numOperands, resizableOperandList),
                            alignof(Operation));
  // NOTE(review): the result of malloc is not checked for null here.
  void *rawMem = malloc(byteSize);

  // Create the new Operation.
  auto op = ::new (rawMem) Operation(location, name, resultTypes, numSuccessors,
                                     numRegions, attributes);

  assert((numSuccessors == 0 || !op->isKnownNonTerminator()) &&
         "unexpected successors in a non-terminator operation");

  // Initialize the trailing results.
  if (LLVM_UNLIKELY(numTrailingResults > 0)) {
    // We initialize the trailing results with their result number. This makes
    // 'getResultNumber' checks much more efficient. The main purpose for these
    // results is to give an anchor to the main operation anyways, so this is
    // purely an optimization.
    auto *trailingResultIt = op->getTrailingObjects<detail::TrailingOpResult>();
    for (unsigned i = 0; i != numTrailingResults; ++i, ++trailingResultIt)
      trailingResultIt->trailingResultNumber = i;
  }

  // Initialize the regions.
  for (unsigned i = 0; i != numRegions; ++i)
    new (&op->getRegion(i)) Region(op);

  // Initialize the operands. The operand storage is constructed first;
  // presumably getOpOperands() reads it, so this order matters.
  new (&op->getOperandStorage())
      detail::OperandStorage(numOperands, resizableOperandList);
  auto opOperands = op->getOpOperands();
  for (unsigned i = 0; i != numOperands; ++i)
    new (&opOperands[i]) OpOperand(op, operands[i]);

  // Initialize the successors.
  auto blockOperands = op->getBlockOperands();
  for (unsigned i = 0; i != numSuccessors; ++i)
    new (&blockOperands[i]) BlockOperand(op, successors[i]);

  return op;
}

/// Construct the fixed part of an Operation. Result types are encoded in
/// 'resultType': left null when there are no results, the single type when
/// there is exactly one result, and a TupleType of all types otherwise.
Operation::Operation(Location location, OperationName name,
                     ArrayRef<Type> resultTypes, unsigned numSuccessors,
                     unsigned numRegions, const NamedAttributeList &attributes)
    : location(location), numSuccs(numSuccessors), numRegions(numRegions),
      hasSingleResult(false), name(name), attrs(attributes) {
  if (!resultTypes.empty()) {
    // If there is a single result it is stored in-place, otherwise use a tuple.
    hasSingleResult = resultTypes.size() == 1;
    if (hasSingleResult)
      resultType = resultTypes.front();
    else
      resultType = TupleType::get(resultTypes, location->getContext());
  }
}

// Operations are deleted through the destroy() member because they are
// allocated via malloc.
173 Operation::~Operation() { 174 assert(block == nullptr && "operation destroyed but still in a block"); 175 176 // Explicitly run the destructors for the operands and results. 177 getOperandStorage().~OperandStorage(); 178 179 // Explicitly run the destructors for the successors. 180 for (auto &successor : getBlockOperands()) 181 successor.~BlockOperand(); 182 183 // Explicitly destroy the regions. 184 for (auto ®ion : getRegions()) 185 region.~Region(); 186 } 187 188 /// Destroy this operation or one of its subclasses. 189 void Operation::destroy() { 190 this->~Operation(); 191 free(this); 192 } 193 194 /// Return the context this operation is associated with. 195 MLIRContext *Operation::getContext() { return location->getContext(); } 196 197 /// Return the dialect this operation is associated with, or nullptr if the 198 /// associated dialect is not registered. 199 Dialect *Operation::getDialect() { 200 if (auto *abstractOp = getAbstractOperation()) 201 return &abstractOp->dialect; 202 203 // If this operation hasn't been registered or doesn't have abstract 204 // operation, try looking up the dialect name in the context. 205 return getContext()->getRegisteredDialect(getName().getDialect()); 206 } 207 208 Region *Operation::getParentRegion() { 209 return block ? block->getParent() : nullptr; 210 } 211 212 Operation *Operation::getParentOp() { 213 return block ? block->getParentOp() : nullptr; 214 } 215 216 /// Return true if this operation is a proper ancestor of the `other` 217 /// operation. 218 bool Operation::isProperAncestor(Operation *other) { 219 while ((other = other->getParentOp())) 220 if (this == other) 221 return true; 222 return false; 223 } 224 225 /// Replace any uses of 'from' with 'to' within this operation. 
226 void Operation::replaceUsesOfWith(Value from, Value to) { 227 if (from == to) 228 return; 229 for (auto &operand : getOpOperands()) 230 if (operand.get() == from) 231 operand.set(to); 232 } 233 234 /// Replace the current operands of this operation with the ones provided in 235 /// 'operands'. If the operands list is not resizable, the size of 'operands' 236 /// must be less than or equal to the current number of operands. 237 void Operation::setOperands(ValueRange operands) { 238 getOperandStorage().setOperands(this, operands); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // Diagnostics 243 //===----------------------------------------------------------------------===// 244 245 /// Emit an error about fatal conditions with this operation, reporting up to 246 /// any diagnostic handlers that may be listening. 247 InFlightDiagnostic Operation::emitError(const Twine &message) { 248 InFlightDiagnostic diag = mlir::emitError(getLoc(), message); 249 if (getContext()->shouldPrintOpOnDiagnostic()) { 250 // Print out the operation explicitly here so that we can print the generic 251 // form. 252 // TODO(riverriddle) It would be nice if we could instead provide the 253 // specific printing flags when adding the operation as an argument to the 254 // diagnostic. 255 std::string printedOp; 256 { 257 llvm::raw_string_ostream os(printedOp); 258 print(os, OpPrintingFlags().printGenericOpForm().useLocalScope()); 259 } 260 diag.attachNote(getLoc()) << "see current operation: " << printedOp; 261 } 262 return diag; 263 } 264 265 /// Emit a warning about this operation, reporting up to any diagnostic 266 /// handlers that may be listening. 
267 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 268 InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); 269 if (getContext()->shouldPrintOpOnDiagnostic()) 270 diag.attachNote(getLoc()) << "see current operation: " << *this; 271 return diag; 272 } 273 274 /// Emit a remark about this operation, reporting up to any diagnostic 275 /// handlers that may be listening. 276 InFlightDiagnostic Operation::emitRemark(const Twine &message) { 277 InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); 278 if (getContext()->shouldPrintOpOnDiagnostic()) 279 diag.attachNote(getLoc()) << "see current operation: " << *this; 280 return diag; 281 } 282 283 //===----------------------------------------------------------------------===// 284 // Operation Ordering 285 //===----------------------------------------------------------------------===// 286 287 constexpr unsigned Operation::kInvalidOrderIdx; 288 constexpr unsigned Operation::kOrderStride; 289 290 /// Given an operation 'other' that is within the same parent block, return 291 /// whether the current operation is before 'other' in the operation list 292 /// of the parent block. 293 /// Note: This function has an average complexity of O(1), but worst case may 294 /// take O(N) where N is the number of operations within the parent block. 295 bool Operation::isBeforeInBlock(Operation *other) { 296 assert(block && "Operations without parent blocks have no order."); 297 assert(other && other->block == block && 298 "Expected other operation to have the same parent block."); 299 // If the order of the block is already invalid, directly recompute the 300 // parent. 301 if (!block->isOpOrderValid()) { 302 block->recomputeOpOrder(); 303 } else { 304 // Update the order either operation if necessary. 
305 updateOrderIfNecessary(); 306 other->updateOrderIfNecessary(); 307 } 308 309 return orderIndex < other->orderIndex; 310 } 311 312 /// Update the order index of this operation of this operation if necessary, 313 /// potentially recomputing the order of the parent block. 314 void Operation::updateOrderIfNecessary() { 315 assert(block && "expected valid parent"); 316 317 // If the order is valid for this operation there is nothing to do. 318 if (hasValidOrder()) 319 return; 320 Operation *blockFront = &block->front(); 321 Operation *blockBack = &block->back(); 322 323 // This method is expected to only be invoked on blocks with more than one 324 // operation. 325 assert(blockFront != blockBack && "expected more than one operation"); 326 327 // If the operation is at the end of the block. 328 if (this == blockBack) { 329 Operation *prevNode = getPrevNode(); 330 if (!prevNode->hasValidOrder()) 331 return block->recomputeOpOrder(); 332 333 // Add the stride to the previous operation. 334 orderIndex = prevNode->orderIndex + kOrderStride; 335 return; 336 } 337 338 // If this is the first operation try to use the next operation to compute the 339 // ordering. 340 if (this == blockFront) { 341 Operation *nextNode = getNextNode(); 342 if (!nextNode->hasValidOrder()) 343 return block->recomputeOpOrder(); 344 // There is no order to give this operation. 345 if (nextNode->orderIndex == 0) 346 return block->recomputeOpOrder(); 347 348 // If we can't use the stride, just take the middle value left. This is safe 349 // because we know there is at least one valid index to assign to. 350 if (nextNode->orderIndex <= kOrderStride) 351 orderIndex = (nextNode->orderIndex / 2); 352 else 353 orderIndex = kOrderStride; 354 return; 355 } 356 357 // Otherwise, this operation is between two others. Place this operation in 358 // the middle of the previous and next if possible. 
359 Operation *prevNode = getPrevNode(), *nextNode = getNextNode(); 360 if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder()) 361 return block->recomputeOpOrder(); 362 unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex; 363 364 // Check to see if there is a valid order between the two. 365 if (prevOrder + 1 == nextOrder) 366 return block->recomputeOpOrder(); 367 orderIndex = prevOrder + 1 + ((nextOrder - prevOrder) / 2); 368 } 369 370 //===----------------------------------------------------------------------===// 371 // ilist_traits for Operation 372 //===----------------------------------------------------------------------===// 373 374 auto llvm::ilist_detail::SpecificNodeAccess< 375 typename llvm::ilist_detail::compute_node_options< 376 ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * { 377 return NodeAccess::getNodePtr<OptionsT>(N); 378 } 379 380 auto llvm::ilist_detail::SpecificNodeAccess< 381 typename llvm::ilist_detail::compute_node_options< 382 ::mlir::Operation>::type>::getNodePtr(const_pointer N) 383 -> const node_type * { 384 return NodeAccess::getNodePtr<OptionsT>(N); 385 } 386 387 auto llvm::ilist_detail::SpecificNodeAccess< 388 typename llvm::ilist_detail::compute_node_options< 389 ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer { 390 return NodeAccess::getValuePtr<OptionsT>(N); 391 } 392 393 auto llvm::ilist_detail::SpecificNodeAccess< 394 typename llvm::ilist_detail::compute_node_options< 395 ::mlir::Operation>::type>::getValuePtr(const node_type *N) 396 -> const_pointer { 397 return NodeAccess::getValuePtr<OptionsT>(N); 398 } 399 400 void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { 401 op->destroy(); 402 } 403 404 Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { 405 size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); 406 iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this)); 407 return 
reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset); 408 } 409 410 /// This is a trait method invoked when an operation is added to a block. We 411 /// keep the block pointer up to date. 412 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 413 assert(!op->getBlock() && "already in an operation block!"); 414 op->block = getContainingBlock(); 415 416 // Invalidate the order on the operation. 417 op->orderIndex = Operation::kInvalidOrderIdx; 418 } 419 420 /// This is a trait method invoked when an operation is removed from a block. 421 /// We keep the block pointer up to date. 422 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 423 assert(op->block && "not already in an operation block!"); 424 op->block = nullptr; 425 } 426 427 /// This is a trait method invoked when an operation is moved from one block 428 /// to another. We keep the block pointer up to date. 429 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 430 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 431 Block *curParent = getContainingBlock(); 432 433 // Invalidate the ordering of the parent block. 434 curParent->invalidateOpOrder(); 435 436 // If we are transferring operations within the same block, the block 437 // pointer doesn't need to be updated. 438 if (curParent == otherList.getContainingBlock()) 439 return; 440 441 // Update the 'block' member of each operation. 442 for (; first != last; ++first) 443 first->block = curParent; 444 } 445 446 /// Remove this operation (and its descendants) from its Block and delete 447 /// all of them. 448 void Operation::erase() { 449 if (auto *parent = getBlock()) 450 parent->getOperations().erase(this); 451 else 452 destroy(); 453 } 454 455 /// Unlink this operation from its current block and insert it right before 456 /// `existingOp` which may be in the same or another block in the same 457 /// function. 
458 void Operation::moveBefore(Operation *existingOp) { 459 moveBefore(existingOp->getBlock(), existingOp->getIterator()); 460 } 461 462 /// Unlink this operation from its current basic block and insert it right 463 /// before `iterator` in the specified basic block. 464 void Operation::moveBefore(Block *block, 465 llvm::iplist<Operation>::iterator iterator) { 466 block->getOperations().splice(iterator, getBlock()->getOperations(), 467 getIterator()); 468 } 469 470 /// This drops all operand uses from this operation, which is an essential 471 /// step in breaking cyclic dependences between references when they are to 472 /// be deleted. 473 void Operation::dropAllReferences() { 474 for (auto &op : getOpOperands()) 475 op.drop(); 476 477 for (auto ®ion : getRegions()) 478 region.dropAllReferences(); 479 480 for (auto &dest : getBlockOperands()) 481 dest.drop(); 482 } 483 484 /// This drops all uses of any values defined by this operation or its nested 485 /// regions, wherever they are located. 486 void Operation::dropAllDefinedValueUses() { 487 dropAllUses(); 488 489 for (auto ®ion : getRegions()) 490 for (auto &block : region) 491 block.dropAllDefinedValueUses(); 492 } 493 494 /// Return the number of results held by this operation. 495 unsigned Operation::getNumResults() { 496 if (!resultType) 497 return 0; 498 return hasSingleResult ? 1 : resultType.cast<TupleType>().size(); 499 } 500 501 auto Operation::getResultTypes() -> result_type_range { 502 if (!resultType) 503 return llvm::None; 504 if (hasSingleResult) 505 return resultType; 506 return resultType.cast<TupleType>().getTypes(); 507 } 508 509 void Operation::setSuccessor(Block *block, unsigned index) { 510 assert(index < getNumSuccessors()); 511 getBlockOperands()[index].set(block); 512 } 513 514 /// Attempt to fold this operation using the Op's registered foldHook. 
515 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 516 SmallVectorImpl<OpFoldResult> &results) { 517 // If we have a registered operation definition matching this one, use it to 518 // try to constant fold the operation. 519 auto *abstractOp = getAbstractOperation(); 520 if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results))) 521 return success(); 522 523 // Otherwise, fall back on the dialect hook to handle it. 524 Dialect *dialect = getDialect(); 525 if (!dialect) 526 return failure(); 527 528 SmallVector<Attribute, 8> constants; 529 if (failed(dialect->constantFoldHook(this, operands, constants))) 530 return failure(); 531 results.assign(constants.begin(), constants.end()); 532 return success(); 533 } 534 535 /// Emit an error with the op name prefixed, like "'dim' op " which is 536 /// convenient for verifiers. 537 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 538 return emitError() << "'" << getName() << "' op " << message; 539 } 540 541 //===----------------------------------------------------------------------===// 542 // Operation Cloning 543 //===----------------------------------------------------------------------===// 544 545 /// Create a deep copy of this operation but keep the operation regions empty. 546 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 547 /// to contain the results. 548 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 549 SmallVector<Value, 8> operands; 550 SmallVector<Block *, 2> successors; 551 552 // Remap the operands. 553 operands.reserve(getNumOperands()); 554 for (auto opValue : getOperands()) 555 operands.push_back(mapper.lookupOrDefault(opValue)); 556 557 // Remap the successors. 558 successors.reserve(getNumSuccessors()); 559 for (Block *successor : getSuccessors()) 560 successors.push_back(mapper.lookupOrDefault(successor)); 561 562 // Create the new operation. 
563 auto *newOp = Operation::create(getLoc(), getName(), getResultTypes(), 564 operands, attrs, successors, getNumRegions(), 565 hasResizableOperandsList()); 566 567 // Remember the mapping of any results. 568 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 569 mapper.map(getResult(i), newOp->getResult(i)); 570 571 return newOp; 572 } 573 574 Operation *Operation::cloneWithoutRegions() { 575 BlockAndValueMapping mapper; 576 return cloneWithoutRegions(mapper); 577 } 578 579 /// Create a deep copy of this operation, remapping any operands that use 580 /// values outside of the operation using the map that is provided (leaving 581 /// them alone if no entry is present). Replaces references to cloned 582 /// sub-operations to the corresponding operation that is copied, and adds 583 /// those mappings to the map. 584 Operation *Operation::clone(BlockAndValueMapping &mapper) { 585 auto *newOp = cloneWithoutRegions(mapper); 586 587 // Clone the regions. 588 for (unsigned i = 0; i != numRegions; ++i) 589 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 590 591 return newOp; 592 } 593 594 Operation *Operation::clone() { 595 BlockAndValueMapping mapper; 596 return clone(mapper); 597 } 598 599 //===----------------------------------------------------------------------===// 600 // OpState trait class. 601 //===----------------------------------------------------------------------===// 602 603 // The fallback for the parser is to reject the custom assembly form. 604 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { 605 return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); 606 } 607 608 // The fallback for the printer is to print in the generic assembly form. 609 void OpState::print(OpAsmPrinter &p) { p.printGenericOp(getOperation()); } 610 611 /// Emit an error about fatal conditions with this operation, reporting up to 612 /// any diagnostic handlers that may be listening. 
613 InFlightDiagnostic OpState::emitError(const Twine &message) { 614 return getOperation()->emitError(message); 615 } 616 617 /// Emit an error with the op name prefixed, like "'dim' op " which is 618 /// convenient for verifiers. 619 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 620 return getOperation()->emitOpError(message); 621 } 622 623 /// Emit a warning about this operation, reporting up to any diagnostic 624 /// handlers that may be listening. 625 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 626 return getOperation()->emitWarning(message); 627 } 628 629 /// Emit a remark about this operation, reporting up to any diagnostic 630 /// handlers that may be listening. 631 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 632 return getOperation()->emitRemark(message); 633 } 634 635 //===----------------------------------------------------------------------===// 636 // Op Trait implementations 637 //===----------------------------------------------------------------------===// 638 639 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 640 if (op->getNumOperands() != 0) 641 return op->emitOpError() << "requires zero operands"; 642 return success(); 643 } 644 645 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 646 if (op->getNumOperands() != 1) 647 return op->emitOpError() << "requires a single operand"; 648 return success(); 649 } 650 651 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 652 unsigned numOperands) { 653 if (op->getNumOperands() != numOperands) { 654 return op->emitOpError() << "expected " << numOperands 655 << " operands, but found " << op->getNumOperands(); 656 } 657 return success(); 658 } 659 660 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 661 unsigned numOperands) { 662 if (op->getNumOperands() < numOperands) 663 return op->emitOpError() 664 << "expected " << numOperands << " or more operands"; 665 return success(); 666 } 667 668 /// 
If this is a vector type, or a tensor type, return the scalar element type 669 /// that it is built around, otherwise return the type unmodified. 670 static Type getTensorOrVectorElementType(Type type) { 671 if (auto vec = type.dyn_cast<VectorType>()) 672 return vec.getElementType(); 673 674 // Look through tensor<vector<...>> to find the underlying element type. 675 if (auto tensor = type.dyn_cast<TensorType>()) 676 return getTensorOrVectorElementType(tensor.getElementType()); 677 return type; 678 } 679 680 LogicalResult 681 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { 682 for (auto opType : op->getOperandTypes()) { 683 auto type = getTensorOrVectorElementType(opType); 684 if (!type.isSignlessIntOrIndex()) 685 return op->emitOpError() << "requires an integer or index type"; 686 } 687 return success(); 688 } 689 690 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 691 for (auto opType : op->getOperandTypes()) { 692 auto type = getTensorOrVectorElementType(opType); 693 if (!type.isa<FloatType>()) 694 return op->emitOpError("requires a float type"); 695 } 696 return success(); 697 } 698 699 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 700 // Zero or one operand always have the "same" type. 
701 unsigned nOperands = op->getNumOperands(); 702 if (nOperands < 2) 703 return success(); 704 705 auto type = op->getOperand(0).getType(); 706 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 707 if (opType != type) 708 return op->emitOpError() << "requires all operands to have the same type"; 709 return success(); 710 } 711 712 LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) { 713 if (op->getNumRegions() != 0) 714 return op->emitOpError() << "requires zero regions"; 715 return success(); 716 } 717 718 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { 719 if (op->getNumRegions() != 1) 720 return op->emitOpError() << "requires one region"; 721 return success(); 722 } 723 724 LogicalResult OpTrait::impl::verifyNRegions(Operation *op, 725 unsigned numRegions) { 726 if (op->getNumRegions() != numRegions) 727 return op->emitOpError() << "expected " << numRegions << " regions"; 728 return success(); 729 } 730 731 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, 732 unsigned numRegions) { 733 if (op->getNumRegions() < numRegions) 734 return op->emitOpError() << "expected " << numRegions << " or more regions"; 735 return success(); 736 } 737 738 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { 739 if (op->getNumResults() != 0) 740 return op->emitOpError() << "requires zero results"; 741 return success(); 742 } 743 744 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 745 if (op->getNumResults() != 1) 746 return op->emitOpError() << "requires one result"; 747 return success(); 748 } 749 750 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 751 unsigned numOperands) { 752 if (op->getNumResults() != numOperands) 753 return op->emitOpError() << "expected " << numOperands << " results"; 754 return success(); 755 } 756 757 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 758 unsigned numOperands) { 759 if (op->getNumResults() < numOperands) 760 return op->emitOpError() 
761 << "expected " << numOperands << " or more results"; 762 return success(); 763 } 764 765 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 766 if (failed(verifyAtLeastNOperands(op, 1))) 767 return failure(); 768 769 auto type = op->getOperand(0).getType(); 770 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) { 771 if (failed(verifyCompatibleShape(opType, type))) 772 return op->emitOpError() << "requires the same shape for all operands"; 773 } 774 return success(); 775 } 776 777 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 778 if (failed(verifyAtLeastNOperands(op, 1)) || 779 failed(verifyAtLeastNResults(op, 1))) 780 return failure(); 781 782 auto type = op->getOperand(0).getType(); 783 for (auto resultType : op->getResultTypes()) { 784 if (failed(verifyCompatibleShape(resultType, type))) 785 return op->emitOpError() 786 << "requires the same shape for all operands and results"; 787 } 788 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) { 789 if (failed(verifyCompatibleShape(opType, type))) 790 return op->emitOpError() 791 << "requires the same shape for all operands and results"; 792 } 793 return success(); 794 } 795 796 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 797 if (failed(verifyAtLeastNOperands(op, 1))) 798 return failure(); 799 auto elementType = getElementTypeOrSelf(op->getOperand(0)); 800 801 for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { 802 if (getElementTypeOrSelf(operand) != elementType) 803 return op->emitOpError("requires the same element type for all operands"); 804 } 805 806 return success(); 807 } 808 809 LogicalResult 810 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 811 if (failed(verifyAtLeastNOperands(op, 1)) || 812 failed(verifyAtLeastNResults(op, 1))) 813 return failure(); 814 815 auto elementType = getElementTypeOrSelf(op->getResult(0)); 816 817 // Verify result element type matches 
first result's element type. 818 for (auto result : llvm::drop_begin(op->getResults(), 1)) { 819 if (getElementTypeOrSelf(result) != elementType) 820 return op->emitOpError( 821 "requires the same element type for all operands and results"); 822 } 823 824 // Verify operand's element type matches first result's element type. 825 for (auto operand : op->getOperands()) { 826 if (getElementTypeOrSelf(operand) != elementType) 827 return op->emitOpError( 828 "requires the same element type for all operands and results"); 829 } 830 831 return success(); 832 } 833 834 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 835 if (failed(verifyAtLeastNOperands(op, 1)) || 836 failed(verifyAtLeastNResults(op, 1))) 837 return failure(); 838 839 auto type = op->getResult(0).getType(); 840 auto elementType = getElementTypeOrSelf(type); 841 for (auto resultType : op->getResultTypes().drop_front(1)) { 842 if (getElementTypeOrSelf(resultType) != elementType || 843 failed(verifyCompatibleShape(resultType, type))) 844 return op->emitOpError() 845 << "requires the same type for all operands and results"; 846 } 847 for (auto opType : op->getOperandTypes()) { 848 if (getElementTypeOrSelf(opType) != elementType || 849 failed(verifyCompatibleShape(opType, type))) 850 return op->emitOpError() 851 << "requires the same type for all operands and results"; 852 } 853 return success(); 854 } 855 856 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 857 Block *block = op->getBlock(); 858 // Verify that the operation is at the end of the respective parent block. 859 if (!block || &block->back() != op) 860 return op->emitOpError("must be the last operation in the parent block"); 861 return success(); 862 } 863 864 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 865 auto *parent = op->getParentRegion(); 866 867 // Verify that the operands lines up with the BB arguments in the successor. 
868 for (Block *succ : op->getSuccessors()) 869 if (succ->getParent() != parent) 870 return op->emitError("reference to block defined in another region"); 871 return success(); 872 } 873 874 LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { 875 if (op->getNumSuccessors() != 0) { 876 return op->emitOpError("requires 0 successors but found ") 877 << op->getNumSuccessors(); 878 } 879 return success(); 880 } 881 882 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { 883 if (op->getNumSuccessors() != 1) { 884 return op->emitOpError("requires 1 successor but found ") 885 << op->getNumSuccessors(); 886 } 887 return verifyTerminatorSuccessors(op); 888 } 889 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, 890 unsigned numSuccessors) { 891 if (op->getNumSuccessors() != numSuccessors) { 892 return op->emitOpError("requires ") 893 << numSuccessors << " successors but found " 894 << op->getNumSuccessors(); 895 } 896 return verifyTerminatorSuccessors(op); 897 } 898 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, 899 unsigned numSuccessors) { 900 if (op->getNumSuccessors() < numSuccessors) { 901 return op->emitOpError("requires at least ") 902 << numSuccessors << " successors but found " 903 << op->getNumSuccessors(); 904 } 905 return verifyTerminatorSuccessors(op); 906 } 907 908 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 909 for (auto resultType : op->getResultTypes()) { 910 auto elementType = getTensorOrVectorElementType(resultType); 911 bool isBoolType = elementType.isInteger(1); 912 if (!isBoolType) 913 return op->emitOpError() << "requires a bool result type"; 914 } 915 916 return success(); 917 } 918 919 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 920 for (auto resultType : op->getResultTypes()) 921 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 922 return op->emitOpError() << "requires a floating point type"; 923 924 return success(); 
925 } 926 927 LogicalResult 928 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { 929 for (auto resultType : op->getResultTypes()) 930 if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) 931 return op->emitOpError() << "requires an integer or index type"; 932 return success(); 933 } 934 935 static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName, 936 bool isOperand) { 937 auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); 938 if (!sizeAttr) 939 return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; 940 941 auto sizeAttrType = sizeAttr.getType().dyn_cast<VectorType>(); 942 if (!sizeAttrType || sizeAttrType.getRank() != 1) 943 return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; 944 945 if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) { 946 return !element.isNonNegative(); 947 })) 948 return op->emitOpError("'") 949 << attrName << "' attribute cannot have negative elements"; 950 951 size_t totalCount = std::accumulate( 952 sizeAttr.begin(), sizeAttr.end(), 0, 953 [](unsigned all, APInt one) { return all + one.getZExtValue(); }); 954 955 if (isOperand && totalCount != op->getNumOperands()) 956 return op->emitOpError("operand count (") 957 << op->getNumOperands() << ") does not match with the total size (" 958 << totalCount << ") specified in attribute '" << attrName << "'"; 959 else if (!isOperand && totalCount != op->getNumResults()) 960 return op->emitOpError("result count (") 961 << op->getNumResults() << ") does not match with the total size (" 962 << totalCount << ") specified in attribute '" << attrName << "'"; 963 return success(); 964 } 965 966 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, 967 StringRef attrName) { 968 return verifyValueSizeAttr(op, attrName, /*isOperand=*/true); 969 } 970 971 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, 972 StringRef attrName) { 973 return 
verifyValueSizeAttr(op, attrName, /*isOperand=*/false); 974 } 975 976 //===----------------------------------------------------------------------===// 977 // BinaryOp implementation 978 //===----------------------------------------------------------------------===// 979 980 // These functions are out-of-line implementations of the methods in BinaryOp, 981 // which avoids them being template instantiated/duplicated. 982 983 void impl::buildBinaryOp(Builder *builder, OperationState &result, Value lhs, 984 Value rhs) { 985 assert(lhs.getType() == rhs.getType()); 986 result.addOperands({lhs, rhs}); 987 result.types.push_back(lhs.getType()); 988 } 989 990 ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser, 991 OperationState &result) { 992 SmallVector<OpAsmParser::OperandType, 2> ops; 993 Type type; 994 return failure(parser.parseOperandList(ops) || 995 parser.parseOptionalAttrDict(result.attributes) || 996 parser.parseColonType(type) || 997 parser.resolveOperands(ops, type, result.operands) || 998 parser.addTypeToList(type, result.types)); 999 } 1000 1001 void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { 1002 assert(op->getNumResults() == 1 && "op should have one result"); 1003 1004 // If not all the operand and result types are the same, just use the 1005 // generic assembly form to avoid omitting information in printing. 1006 auto resultType = op->getResult(0).getType(); 1007 if (llvm::any_of(op->getOperandTypes(), 1008 [&](Type type) { return type != resultType; })) { 1009 p.printGenericOp(op); 1010 return; 1011 } 1012 1013 p << op->getName() << ' '; 1014 p.printOperands(op->getOperands()); 1015 p.printOptionalAttrDict(op->getAttrs()); 1016 // Now we can output only one type for all operands and the result. 
1017 p << " : " << resultType; 1018 } 1019 1020 //===----------------------------------------------------------------------===// 1021 // CastOp implementation 1022 //===----------------------------------------------------------------------===// 1023 1024 void impl::buildCastOp(Builder *builder, OperationState &result, Value source, 1025 Type destType) { 1026 result.addOperands(source); 1027 result.addTypes(destType); 1028 } 1029 1030 ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) { 1031 OpAsmParser::OperandType srcInfo; 1032 Type srcType, dstType; 1033 return failure(parser.parseOperand(srcInfo) || 1034 parser.parseOptionalAttrDict(result.attributes) || 1035 parser.parseColonType(srcType) || 1036 parser.resolveOperand(srcInfo, srcType, result.operands) || 1037 parser.parseKeywordType("to", dstType) || 1038 parser.addTypeToList(dstType, result.types)); 1039 } 1040 1041 void impl::printCastOp(Operation *op, OpAsmPrinter &p) { 1042 p << op->getName() << ' ' << op->getOperand(0); 1043 p.printOptionalAttrDict(op->getAttrs()); 1044 p << " : " << op->getOperand(0).getType() << " to " 1045 << op->getResult(0).getType(); 1046 } 1047 1048 Value impl::foldCastOp(Operation *op) { 1049 // Identity cast 1050 if (op->getOperand(0).getType() == op->getResult(0).getType()) 1051 return op->getOperand(0); 1052 return nullptr; 1053 } 1054 1055 //===----------------------------------------------------------------------===// 1056 // Misc. utils 1057 //===----------------------------------------------------------------------===// 1058 1059 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1060 /// region's only block if it does not have a terminator already. If the region 1061 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1062 /// terminator operation to insert. 
1063 void impl::ensureRegionTerminator( 1064 Region ®ion, Location loc, 1065 function_ref<Operation *()> buildTerminatorOp) { 1066 if (region.empty()) 1067 region.push_back(new Block); 1068 1069 Block &block = region.back(); 1070 if (!block.empty() && block.back().isKnownTerminator()) 1071 return; 1072 1073 block.push_back(buildTerminatorOp()); 1074 } 1075