1 //===- Operation.cpp - Operation support code -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Operation.h"
10 #include "mlir/IR/BlockAndValueMapping.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/OpImplementation.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/IR/StandardTypes.h"
15 #include "mlir/IR/TypeUtilities.h"
16 #include <numeric>
17 
18 using namespace mlir;
19 
20 OpAsmParser::~OpAsmParser() {}
21 
22 //===----------------------------------------------------------------------===//
23 // OperationName
24 //===----------------------------------------------------------------------===//
25 
26 /// Form the OperationName for an op with the specified string.  This either is
27 /// a reference to an AbstractOperation if one is known, or a uniqued Identifier
28 /// if not.
29 OperationName::OperationName(StringRef name, MLIRContext *context) {
30   if (auto *op = AbstractOperation::lookup(name, context))
31     representation = op;
32   else
33     representation = Identifier::get(name, context);
34 }
35 
36 /// Return the name of the dialect this operation is registered to.
37 StringRef OperationName::getDialect() const {
38   return getStringRef().split('.').first;
39 }
40 
41 /// Return the name of this operation.  This always succeeds.
42 StringRef OperationName::getStringRef() const {
43   if (auto *op = representation.dyn_cast<const AbstractOperation *>())
44     return op->name;
45   return representation.get<Identifier>().strref();
46 }
47 
/// Return the AbstractOperation this name refers to, as obtained by a
/// dyn_cast on the underlying representation (callers in this file treat a
/// null result as "no AbstractOperation available").
const AbstractOperation *OperationName::getAbstractOperation() const {
  return representation.dyn_cast<const AbstractOperation *>();
}
51 
/// Rebuild an OperationName from an opaque pointer by reinterpreting it as
/// the underlying RepresentationUnion value.
// NOTE(review): presumably `pointer` must come from the matching
// "get as opaque pointer" accessor — confirm against the header.
OperationName OperationName::getFromOpaquePointer(void *pointer) {
  return OperationName(RepresentationUnion::getFromOpaqueValue(pointer));
}
55 
56 //===----------------------------------------------------------------------===//
57 // Operation
58 //===----------------------------------------------------------------------===//
59 
60 /// Create a new Operation with the specific fields.
61 Operation *Operation::create(Location location, OperationName name,
62                              ArrayRef<Type> resultTypes,
63                              ArrayRef<Value> operands,
64                              ArrayRef<NamedAttribute> attributes,
65                              ArrayRef<Block *> successors, unsigned numRegions,
66                              bool resizableOperandList) {
67   return create(location, name, resultTypes, operands,
68                 NamedAttributeList(attributes), successors, numRegions,
69                 resizableOperandList);
70 }
71 
72 /// Create a new Operation from operation state.
73 Operation *Operation::create(const OperationState &state) {
74   return Operation::create(state.location, state.name, state.types,
75                            state.operands, NamedAttributeList(state.attributes),
76                            state.successors, state.regions,
77                            state.resizableOperandList);
78 }
79 
80 /// Create a new Operation with the specific fields.
81 Operation *Operation::create(Location location, OperationName name,
82                              ArrayRef<Type> resultTypes,
83                              ArrayRef<Value> operands,
84                              NamedAttributeList attributes,
85                              ArrayRef<Block *> successors, RegionRange regions,
86                              bool resizableOperandList) {
87   unsigned numRegions = regions.size();
88   Operation *op = create(location, name, resultTypes, operands, attributes,
89                          successors, numRegions, resizableOperandList);
90   for (unsigned i = 0; i < numRegions; ++i)
91     if (regions[i])
92       op->getRegion(i).takeBody(*regions[i]);
93   return op;
94 }
95 
/// Overload of create that takes an existing NamedAttributeList to avoid
/// unnecessarily uniquing a list of attributes.
///
/// This is the overload that actually allocates: it sizes one malloc'ed
/// buffer for the Operation plus its trailing objects (trailing results,
/// block operands, regions and the operand storage), then constructs each of
/// those pieces in place. The returned operation is released via destroy().
Operation *Operation::create(Location location, OperationName name,
                             ArrayRef<Type> resultTypes,
                             ArrayRef<Value> operands,
                             NamedAttributeList attributes,
                             ArrayRef<Block *> successors, unsigned numRegions,
                             bool resizableOperandList) {
  // We only need to allocate additional memory for a subset of results.
  unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size());
  unsigned numSuccessors = successors.size();
  unsigned numOperands = operands.size();

  // Compute the byte size for the operation and the operand storage.
  auto byteSize = totalSizeToAlloc<detail::TrailingOpResult, BlockOperand,
                                   Region, detail::OperandStorage>(
      numTrailingResults, numSuccessors, numRegions,
      /*detail::OperandStorage*/ 1);
  // The operand storage asks for extra bytes beyond its own size; that amount
  // is rounded up to the alignment of Operation before being added.
  byteSize += llvm::alignTo(detail::OperandStorage::additionalAllocSize(
                                numOperands, resizableOperandList),
                            alignof(Operation));
  // NOTE(review): the malloc result is used below without a null check.
  void *rawMem = malloc(byteSize);

  // Create the new Operation.
  auto op = ::new (rawMem) Operation(location, name, resultTypes, numSuccessors,
                                     numRegions, attributes);

  // Successors are only accepted on operations that are not known to be
  // non-terminators.
  assert((numSuccessors == 0 || !op->isKnownNonTerminator()) &&
         "unexpected successors in a non-terminator operation");

  // Initialize the trailing results.
  if (LLVM_UNLIKELY(numTrailingResults > 0)) {
    // We initialize the trailing results with their result number. This makes
    // 'getResultNumber' checks much more efficient. The main purpose for these
    // results is to give an anchor to the main operation anyways, so this is
    // purely an optimization.
    auto *trailingResultIt = op->getTrailingObjects<detail::TrailingOpResult>();
    for (unsigned i = 0; i != numTrailingResults; ++i, ++trailingResultIt)
      trailingResultIt->trailingResultNumber = i;
  }

  // Initialize the regions.
  for (unsigned i = 0; i != numRegions; ++i)
    new (&op->getRegion(i)) Region(op);

  // Initialize the operands.
  // The storage object is constructed first; each OpOperand is then
  // constructed in place with this operation as owner.
  new (&op->getOperandStorage())
      detail::OperandStorage(numOperands, resizableOperandList);
  auto opOperands = op->getOpOperands();
  for (unsigned i = 0; i != numOperands; ++i)
    new (&opOperands[i]) OpOperand(op, operands[i]);

  // Initialize the successors.
  auto blockOperands = op->getBlockOperands();
  for (unsigned i = 0; i != numSuccessors; ++i)
    new (&blockOperands[i]) BlockOperand(op, successors[i]);

  return op;
}
155 
156 Operation::Operation(Location location, OperationName name,
157                      ArrayRef<Type> resultTypes, unsigned numSuccessors,
158                      unsigned numRegions, const NamedAttributeList &attributes)
159     : location(location), numSuccs(numSuccessors), numRegions(numRegions),
160       hasSingleResult(false), name(name), attrs(attributes) {
161   if (!resultTypes.empty()) {
162     // If there is a single result it is stored in-place, otherwise use a tuple.
163     hasSingleResult = resultTypes.size() == 1;
164     if (hasSingleResult)
165       resultType = resultTypes.front();
166     else
167       resultType = TupleType::get(resultTypes, location->getContext());
168   }
169 }
170 
171 // Operations are deleted through the destroy() member because they are
172 // allocated via malloc.
173 Operation::~Operation() {
174   assert(block == nullptr && "operation destroyed but still in a block");
175 
176   // Explicitly run the destructors for the operands and results.
177   getOperandStorage().~OperandStorage();
178 
179   // Explicitly run the destructors for the successors.
180   for (auto &successor : getBlockOperands())
181     successor.~BlockOperand();
182 
183   // Explicitly destroy the regions.
184   for (auto &region : getRegions())
185     region.~Region();
186 }
187 
/// Destroy this operation or one of its subclasses.
/// Runs the destructor explicitly and then releases the memory with free(),
/// matching the malloc() performed when the operation was created.
void Operation::destroy() {
  this->~Operation();
  free(this);
}
193 
/// Return the context this operation is associated with. The context is
/// obtained from the operation's location.
MLIRContext *Operation::getContext() { return location->getContext(); }
196 
197 /// Return the dialect this operation is associated with, or nullptr if the
198 /// associated dialect is not registered.
199 Dialect *Operation::getDialect() {
200   if (auto *abstractOp = getAbstractOperation())
201     return &abstractOp->dialect;
202 
203   // If this operation hasn't been registered or doesn't have abstract
204   // operation, try looking up the dialect name in the context.
205   return getContext()->getRegisteredDialect(getName().getDialect());
206 }
207 
208 Region *Operation::getParentRegion() {
209   return block ? block->getParent() : nullptr;
210 }
211 
212 Operation *Operation::getParentOp() {
213   return block ? block->getParentOp() : nullptr;
214 }
215 
216 /// Return true if this operation is a proper ancestor of the `other`
217 /// operation.
218 bool Operation::isProperAncestor(Operation *other) {
219   while ((other = other->getParentOp()))
220     if (this == other)
221       return true;
222   return false;
223 }
224 
225 /// Replace any uses of 'from' with 'to' within this operation.
226 void Operation::replaceUsesOfWith(Value from, Value to) {
227   if (from == to)
228     return;
229   for (auto &operand : getOpOperands())
230     if (operand.get() == from)
231       operand.set(to);
232 }
233 
/// Replace the current operands of this operation with the ones provided in
/// 'operands'. If the operands list is not resizable, the size of 'operands'
/// must be less than or equal to the current number of operands.
/// The work is delegated to the operand storage, which receives this
/// operation as the owner of the new operands.
void Operation::setOperands(ValueRange operands) {
  getOperandStorage().setOperands(this, operands);
}
240 
241 //===----------------------------------------------------------------------===//
242 // Diagnostics
243 //===----------------------------------------------------------------------===//
244 
245 /// Emit an error about fatal conditions with this operation, reporting up to
246 /// any diagnostic handlers that may be listening.
247 InFlightDiagnostic Operation::emitError(const Twine &message) {
248   InFlightDiagnostic diag = mlir::emitError(getLoc(), message);
249   if (getContext()->shouldPrintOpOnDiagnostic()) {
250     // Print out the operation explicitly here so that we can print the generic
251     // form.
252     // TODO(riverriddle) It would be nice if we could instead provide the
253     // specific printing flags when adding the operation as an argument to the
254     // diagnostic.
255     std::string printedOp;
256     {
257       llvm::raw_string_ostream os(printedOp);
258       print(os, OpPrintingFlags().printGenericOpForm().useLocalScope());
259     }
260     diag.attachNote(getLoc()) << "see current operation: " << printedOp;
261   }
262   return diag;
263 }
264 
265 /// Emit a warning about this operation, reporting up to any diagnostic
266 /// handlers that may be listening.
267 InFlightDiagnostic Operation::emitWarning(const Twine &message) {
268   InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message);
269   if (getContext()->shouldPrintOpOnDiagnostic())
270     diag.attachNote(getLoc()) << "see current operation: " << *this;
271   return diag;
272 }
273 
274 /// Emit a remark about this operation, reporting up to any diagnostic
275 /// handlers that may be listening.
276 InFlightDiagnostic Operation::emitRemark(const Twine &message) {
277   InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message);
278   if (getContext()->shouldPrintOpOnDiagnostic())
279     diag.attachNote(getLoc()) << "see current operation: " << *this;
280   return diag;
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // Operation Ordering
285 //===----------------------------------------------------------------------===//
286 
// Out-of-line definitions of the constexpr static data members of Operation.
// NOTE(review): presumably required so that odr-uses of these constants link
// under pre-C++17 rules — confirm against the project's language standard.
constexpr unsigned Operation::kInvalidOrderIdx;
constexpr unsigned Operation::kOrderStride;
289 
290 /// Given an operation 'other' that is within the same parent block, return
291 /// whether the current operation is before 'other' in the operation list
292 /// of the parent block.
293 /// Note: This function has an average complexity of O(1), but worst case may
294 /// take O(N) where N is the number of operations within the parent block.
295 bool Operation::isBeforeInBlock(Operation *other) {
296   assert(block && "Operations without parent blocks have no order.");
297   assert(other && other->block == block &&
298          "Expected other operation to have the same parent block.");
299   // If the order of the block is already invalid, directly recompute the
300   // parent.
301   if (!block->isOpOrderValid()) {
302     block->recomputeOpOrder();
303   } else {
304     // Update the order either operation if necessary.
305     updateOrderIfNecessary();
306     other->updateOrderIfNecessary();
307   }
308 
309   return orderIndex < other->orderIndex;
310 }
311 
312 /// Update the order index of this operation of this operation if necessary,
313 /// potentially recomputing the order of the parent block.
314 void Operation::updateOrderIfNecessary() {
315   assert(block && "expected valid parent");
316 
317   // If the order is valid for this operation there is nothing to do.
318   if (hasValidOrder())
319     return;
320   Operation *blockFront = &block->front();
321   Operation *blockBack = &block->back();
322 
323   // This method is expected to only be invoked on blocks with more than one
324   // operation.
325   assert(blockFront != blockBack && "expected more than one operation");
326 
327   // If the operation is at the end of the block.
328   if (this == blockBack) {
329     Operation *prevNode = getPrevNode();
330     if (!prevNode->hasValidOrder())
331       return block->recomputeOpOrder();
332 
333     // Add the stride to the previous operation.
334     orderIndex = prevNode->orderIndex + kOrderStride;
335     return;
336   }
337 
338   // If this is the first operation try to use the next operation to compute the
339   // ordering.
340   if (this == blockFront) {
341     Operation *nextNode = getNextNode();
342     if (!nextNode->hasValidOrder())
343       return block->recomputeOpOrder();
344     // There is no order to give this operation.
345     if (nextNode->orderIndex == 0)
346       return block->recomputeOpOrder();
347 
348     // If we can't use the stride, just take the middle value left. This is safe
349     // because we know there is at least one valid index to assign to.
350     if (nextNode->orderIndex <= kOrderStride)
351       orderIndex = (nextNode->orderIndex / 2);
352     else
353       orderIndex = kOrderStride;
354     return;
355   }
356 
357   // Otherwise, this operation is between two others. Place this operation in
358   // the middle of the previous and next if possible.
359   Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
360   if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
361     return block->recomputeOpOrder();
362   unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;
363 
364   // Check to see if there is a valid order between the two.
365   if (prevOrder + 1 == nextOrder)
366     return block->recomputeOpOrder();
367   orderIndex = prevOrder + 1 + ((nextOrder - prevOrder) / 2);
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // ilist_traits for Operation
372 //===----------------------------------------------------------------------===//
373 
// Out-of-line definition of the ilist node accessor for mlir::Operation:
// converts a pointer to an Operation into a pointer to its list node by
// forwarding to NodeAccess::getNodePtr with this specialization's options.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(N);
}
379 
// Const overload of the Operation-to-list-node conversion; forwards to
// NodeAccess::getNodePtr with this specialization's options.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer N)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(N);
}
386 
// Out-of-line definition of the reverse conversion for mlir::Operation:
// converts a list node pointer back into a pointer to the Operation by
// forwarding to NodeAccess::getValuePtr with this specialization's options.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(N);
}
392 
// Const overload of the list-node-to-Operation conversion; forwards to
// NodeAccess::getValuePtr with this specialization's options.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *N)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(N);
}
399 
/// List trait hook used to delete an operation. Operations are released with
/// destroy() rather than `delete`.
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
  op->destroy();
}
403 
/// Return the Block whose operation list this traits object belongs to.
/// The byte offset of the operation list within a Block is computed from the
/// member pointer returned by Block::getSublistAccess, applied to a null
/// Block pointer; that offset is then subtracted from the address of this
/// list to obtain the address of the enclosing Block.
// NOTE(review): this relies on `this` being the traits base of the
// iplist<Operation> stored inside a Block — confirm against Block's layout.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset);
}
409 
410 /// This is a trait method invoked when an operation is added to a block.  We
411 /// keep the block pointer up to date.
412 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
413   assert(!op->getBlock() && "already in an operation block!");
414   op->block = getContainingBlock();
415 
416   // Invalidate the order on the operation.
417   op->orderIndex = Operation::kInvalidOrderIdx;
418 }
419 
/// This is a trait method invoked when an operation is removed from a block.
/// We keep the block pointer up to date by clearing it; the operation must
/// currently be in a block.
void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) {
  assert(op->block && "not already in an operation block!");
  op->block = nullptr;
}
426 
427 /// This is a trait method invoked when an operation is moved from one block
428 /// to another.  We keep the block pointer up to date.
429 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList(
430     ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) {
431   Block *curParent = getContainingBlock();
432 
433   // Invalidate the ordering of the parent block.
434   curParent->invalidateOpOrder();
435 
436   // If we are transferring operations within the same block, the block
437   // pointer doesn't need to be updated.
438   if (curParent == otherList.getContainingBlock())
439     return;
440 
441   // Update the 'block' member of each operation.
442   for (; first != last; ++first)
443     first->block = curParent;
444 }
445 
446 /// Remove this operation (and its descendants) from its Block and delete
447 /// all of them.
448 void Operation::erase() {
449   if (auto *parent = getBlock())
450     parent->getOperations().erase(this);
451   else
452     destroy();
453 }
454 
/// Unlink this operation from its current block and insert it right before
/// `existingOp` which may be in the same or another block in the same
/// function.
/// Forwards to the (block, iterator) overload using `existingOp`'s block and
/// list position.
void Operation::moveBefore(Operation *existingOp) {
  moveBefore(existingOp->getBlock(), existingOp->getIterator());
}
461 
/// Unlink this operation from its current basic block and insert it right
/// before `iterator` in the specified basic block.
/// Implemented as a single-element splice between the two operation lists.
// NOTE(review): getBlock() is dereferenced without a null check, so this
// operation must currently be in a block.
void Operation::moveBefore(Block *block,
                           llvm::iplist<Operation>::iterator iterator) {
  block->getOperations().splice(iterator, getBlock()->getOperations(),
                                getIterator());
}
469 
470 /// This drops all operand uses from this operation, which is an essential
471 /// step in breaking cyclic dependences between references when they are to
472 /// be deleted.
473 void Operation::dropAllReferences() {
474   for (auto &op : getOpOperands())
475     op.drop();
476 
477   for (auto &region : getRegions())
478     region.dropAllReferences();
479 
480   for (auto &dest : getBlockOperands())
481     dest.drop();
482 }
483 
484 /// This drops all uses of any values defined by this operation or its nested
485 /// regions, wherever they are located.
486 void Operation::dropAllDefinedValueUses() {
487   dropAllUses();
488 
489   for (auto &region : getRegions())
490     for (auto &block : region)
491       block.dropAllDefinedValueUses();
492 }
493 
494 /// Return the number of results held by this operation.
495 unsigned Operation::getNumResults() {
496   if (!resultType)
497     return 0;
498   return hasSingleResult ? 1 : resultType.cast<TupleType>().size();
499 }
500 
501 auto Operation::getResultTypes() -> result_type_range {
502   if (!resultType)
503     return llvm::None;
504   if (hasSingleResult)
505     return resultType;
506   return resultType.cast<TupleType>().getTypes();
507 }
508 
/// Set the successor at position `index` to `block`. `index` must be smaller
/// than the number of successors of this operation.
void Operation::setSuccessor(Block *block, unsigned index) {
  assert(index < getNumSuccessors());
  getBlockOperands()[index].set(block);
}
513 
514 /// Attempt to fold this operation using the Op's registered foldHook.
515 LogicalResult Operation::fold(ArrayRef<Attribute> operands,
516                               SmallVectorImpl<OpFoldResult> &results) {
517   // If we have a registered operation definition matching this one, use it to
518   // try to constant fold the operation.
519   auto *abstractOp = getAbstractOperation();
520   if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results)))
521     return success();
522 
523   // Otherwise, fall back on the dialect hook to handle it.
524   Dialect *dialect = getDialect();
525   if (!dialect)
526     return failure();
527 
528   SmallVector<Attribute, 8> constants;
529   if (failed(dialect->constantFoldHook(this, operands, constants)))
530     return failure();
531   results.assign(constants.begin(), constants.end());
532   return success();
533 }
534 
535 /// Emit an error with the op name prefixed, like "'dim' op " which is
536 /// convenient for verifiers.
537 InFlightDiagnostic Operation::emitOpError(const Twine &message) {
538   return emitError() << "'" << getName() << "' op " << message;
539 }
540 
541 //===----------------------------------------------------------------------===//
542 // Operation Cloning
543 //===----------------------------------------------------------------------===//
544 
545 /// Create a deep copy of this operation but keep the operation regions empty.
546 /// Operands are remapped using `mapper` (if present), and `mapper` is updated
547 /// to contain the results.
548 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) {
549   SmallVector<Value, 8> operands;
550   SmallVector<Block *, 2> successors;
551 
552   // Remap the operands.
553   operands.reserve(getNumOperands());
554   for (auto opValue : getOperands())
555     operands.push_back(mapper.lookupOrDefault(opValue));
556 
557   // Remap the successors.
558   successors.reserve(getNumSuccessors());
559   for (Block *successor : getSuccessors())
560     successors.push_back(mapper.lookupOrDefault(successor));
561 
562   // Create the new operation.
563   auto *newOp = Operation::create(getLoc(), getName(), getResultTypes(),
564                                   operands, attrs, successors, getNumRegions(),
565                                   hasResizableOperandsList());
566 
567   // Remember the mapping of any results.
568   for (unsigned i = 0, e = getNumResults(); i != e; ++i)
569     mapper.map(getResult(i), newOp->getResult(i));
570 
571   return newOp;
572 }
573 
/// Convenience overload of cloneWithoutRegions that starts from an empty,
/// local mapping; the mapping is discarded when the call returns.
Operation *Operation::cloneWithoutRegions() {
  BlockAndValueMapping mapper;
  return cloneWithoutRegions(mapper);
}
578 
579 /// Create a deep copy of this operation, remapping any operands that use
580 /// values outside of the operation using the map that is provided (leaving
581 /// them alone if no entry is present).  Replaces references to cloned
582 /// sub-operations to the corresponding operation that is copied, and adds
583 /// those mappings to the map.
584 Operation *Operation::clone(BlockAndValueMapping &mapper) {
585   auto *newOp = cloneWithoutRegions(mapper);
586 
587   // Clone the regions.
588   for (unsigned i = 0; i != numRegions; ++i)
589     getRegion(i).cloneInto(&newOp->getRegion(i), mapper);
590 
591   return newOp;
592 }
593 
/// Convenience overload of clone that starts from an empty, local mapping;
/// the mapping is discarded when the call returns.
Operation *Operation::clone() {
  BlockAndValueMapping mapper;
  return clone(mapper);
}
598 
599 //===----------------------------------------------------------------------===//
600 // OpState trait class.
601 //===----------------------------------------------------------------------===//
602 
// The fallback for the parser is to reject the custom assembly form.
// An error is emitted at the parser's name location; `result` is not touched.
ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) {
  return parser.emitError(parser.getNameLoc(), "has no custom assembly form");
}
607 
// The fallback for the printer is to print in the generic assembly form,
// by handing the underlying operation to the printer's generic entry point.
void OpState::print(OpAsmPrinter &p) { p.printGenericOp(getOperation()); }
610 
/// Emit an error about fatal conditions with this operation, reporting up to
/// any diagnostic handlers that may be listening.
/// Forwards to the wrapped Operation's emitError.
InFlightDiagnostic OpState::emitError(const Twine &message) {
  return getOperation()->emitError(message);
}
616 
/// Emit an error with the op name prefixed, like "'dim' op " which is
/// convenient for verifiers.
/// Forwards to the wrapped Operation's emitOpError.
InFlightDiagnostic OpState::emitOpError(const Twine &message) {
  return getOperation()->emitOpError(message);
}
622 
/// Emit a warning about this operation, reporting up to any diagnostic
/// handlers that may be listening.
/// Forwards to the wrapped Operation's emitWarning.
InFlightDiagnostic OpState::emitWarning(const Twine &message) {
  return getOperation()->emitWarning(message);
}
628 
/// Emit a remark about this operation, reporting up to any diagnostic
/// handlers that may be listening.
/// Forwards to the wrapped Operation's emitRemark.
InFlightDiagnostic OpState::emitRemark(const Twine &message) {
  return getOperation()->emitRemark(message);
}
634 
635 //===----------------------------------------------------------------------===//
636 // Op Trait implementations
637 //===----------------------------------------------------------------------===//
638 
639 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) {
640   if (op->getNumOperands() != 0)
641     return op->emitOpError() << "requires zero operands";
642   return success();
643 }
644 
645 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) {
646   if (op->getNumOperands() != 1)
647     return op->emitOpError() << "requires a single operand";
648   return success();
649 }
650 
651 LogicalResult OpTrait::impl::verifyNOperands(Operation *op,
652                                              unsigned numOperands) {
653   if (op->getNumOperands() != numOperands) {
654     return op->emitOpError() << "expected " << numOperands
655                              << " operands, but found " << op->getNumOperands();
656   }
657   return success();
658 }
659 
660 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op,
661                                                     unsigned numOperands) {
662   if (op->getNumOperands() < numOperands)
663     return op->emitOpError()
664            << "expected " << numOperands << " or more operands";
665   return success();
666 }
667 
668 /// If this is a vector type, or a tensor type, return the scalar element type
669 /// that it is built around, otherwise return the type unmodified.
670 static Type getTensorOrVectorElementType(Type type) {
671   if (auto vec = type.dyn_cast<VectorType>())
672     return vec.getElementType();
673 
674   // Look through tensor<vector<...>> to find the underlying element type.
675   if (auto tensor = type.dyn_cast<TensorType>())
676     return getTensorOrVectorElementType(tensor.getElementType());
677   return type;
678 }
679 
680 LogicalResult
681 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) {
682   for (auto opType : op->getOperandTypes()) {
683     auto type = getTensorOrVectorElementType(opType);
684     if (!type.isSignlessIntOrIndex())
685       return op->emitOpError() << "requires an integer or index type";
686   }
687   return success();
688 }
689 
690 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) {
691   for (auto opType : op->getOperandTypes()) {
692     auto type = getTensorOrVectorElementType(opType);
693     if (!type.isa<FloatType>())
694       return op->emitOpError("requires a float type");
695   }
696   return success();
697 }
698 
699 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) {
700   // Zero or one operand always have the "same" type.
701   unsigned nOperands = op->getNumOperands();
702   if (nOperands < 2)
703     return success();
704 
705   auto type = op->getOperand(0).getType();
706   for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
707     if (opType != type)
708       return op->emitOpError() << "requires all operands to have the same type";
709   return success();
710 }
711 
712 LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) {
713   if (op->getNumRegions() != 0)
714     return op->emitOpError() << "requires zero regions";
715   return success();
716 }
717 
718 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) {
719   if (op->getNumRegions() != 1)
720     return op->emitOpError() << "requires one region";
721   return success();
722 }
723 
724 LogicalResult OpTrait::impl::verifyNRegions(Operation *op,
725                                             unsigned numRegions) {
726   if (op->getNumRegions() != numRegions)
727     return op->emitOpError() << "expected " << numRegions << " regions";
728   return success();
729 }
730 
731 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op,
732                                                    unsigned numRegions) {
733   if (op->getNumRegions() < numRegions)
734     return op->emitOpError() << "expected " << numRegions << " or more regions";
735   return success();
736 }
737 
738 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) {
739   if (op->getNumResults() != 0)
740     return op->emitOpError() << "requires zero results";
741   return success();
742 }
743 
744 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) {
745   if (op->getNumResults() != 1)
746     return op->emitOpError() << "requires one result";
747   return success();
748 }
749 
750 LogicalResult OpTrait::impl::verifyNResults(Operation *op,
751                                             unsigned numOperands) {
752   if (op->getNumResults() != numOperands)
753     return op->emitOpError() << "expected " << numOperands << " results";
754   return success();
755 }
756 
757 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op,
758                                                    unsigned numOperands) {
759   if (op->getNumResults() < numOperands)
760     return op->emitOpError()
761            << "expected " << numOperands << " or more results";
762   return success();
763 }
764 
765 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
766   if (failed(verifyAtLeastNOperands(op, 1)))
767     return failure();
768 
769   auto type = op->getOperand(0).getType();
770   for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
771     if (failed(verifyCompatibleShape(opType, type)))
772       return op->emitOpError() << "requires the same shape for all operands";
773   }
774   return success();
775 }
776 
777 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
778   if (failed(verifyAtLeastNOperands(op, 1)) ||
779       failed(verifyAtLeastNResults(op, 1)))
780     return failure();
781 
782   auto type = op->getOperand(0).getType();
783   for (auto resultType : op->getResultTypes()) {
784     if (failed(verifyCompatibleShape(resultType, type)))
785       return op->emitOpError()
786              << "requires the same shape for all operands and results";
787   }
788   for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
789     if (failed(verifyCompatibleShape(opType, type)))
790       return op->emitOpError()
791              << "requires the same shape for all operands and results";
792   }
793   return success();
794 }
795 
796 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) {
797   if (failed(verifyAtLeastNOperands(op, 1)))
798     return failure();
799   auto elementType = getElementTypeOrSelf(op->getOperand(0));
800 
801   for (auto operand : llvm::drop_begin(op->getOperands(), 1)) {
802     if (getElementTypeOrSelf(operand) != elementType)
803       return op->emitOpError("requires the same element type for all operands");
804   }
805 
806   return success();
807 }
808 
809 LogicalResult
810 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
811   if (failed(verifyAtLeastNOperands(op, 1)) ||
812       failed(verifyAtLeastNResults(op, 1)))
813     return failure();
814 
815   auto elementType = getElementTypeOrSelf(op->getResult(0));
816 
817   // Verify result element type matches first result's element type.
818   for (auto result : llvm::drop_begin(op->getResults(), 1)) {
819     if (getElementTypeOrSelf(result) != elementType)
820       return op->emitOpError(
821           "requires the same element type for all operands and results");
822   }
823 
824   // Verify operand's element type matches first result's element type.
825   for (auto operand : op->getOperands()) {
826     if (getElementTypeOrSelf(operand) != elementType)
827       return op->emitOpError(
828           "requires the same element type for all operands and results");
829   }
830 
831   return success();
832 }
833 
834 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
835   if (failed(verifyAtLeastNOperands(op, 1)) ||
836       failed(verifyAtLeastNResults(op, 1)))
837     return failure();
838 
839   auto type = op->getResult(0).getType();
840   auto elementType = getElementTypeOrSelf(type);
841   for (auto resultType : op->getResultTypes().drop_front(1)) {
842     if (getElementTypeOrSelf(resultType) != elementType ||
843         failed(verifyCompatibleShape(resultType, type)))
844       return op->emitOpError()
845              << "requires the same type for all operands and results";
846   }
847   for (auto opType : op->getOperandTypes()) {
848     if (getElementTypeOrSelf(opType) != elementType ||
849         failed(verifyCompatibleShape(opType, type)))
850       return op->emitOpError()
851              << "requires the same type for all operands and results";
852   }
853   return success();
854 }
855 
856 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
857   Block *block = op->getBlock();
858   // Verify that the operation is at the end of the respective parent block.
859   if (!block || &block->back() != op)
860     return op->emitOpError("must be the last operation in the parent block");
861   return success();
862 }
863 
864 static LogicalResult verifyTerminatorSuccessors(Operation *op) {
865   auto *parent = op->getParentRegion();
866 
867   // Verify that the operands lines up with the BB arguments in the successor.
868   for (Block *succ : op->getSuccessors())
869     if (succ->getParent() != parent)
870       return op->emitError("reference to block defined in another region");
871   return success();
872 }
873 
874 LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) {
875   if (op->getNumSuccessors() != 0) {
876     return op->emitOpError("requires 0 successors but found ")
877            << op->getNumSuccessors();
878   }
879   return success();
880 }
881 
882 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) {
883   if (op->getNumSuccessors() != 1) {
884     return op->emitOpError("requires 1 successor but found ")
885            << op->getNumSuccessors();
886   }
887   return verifyTerminatorSuccessors(op);
888 }
889 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op,
890                                                unsigned numSuccessors) {
891   if (op->getNumSuccessors() != numSuccessors) {
892     return op->emitOpError("requires ")
893            << numSuccessors << " successors but found "
894            << op->getNumSuccessors();
895   }
896   return verifyTerminatorSuccessors(op);
897 }
898 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op,
899                                                       unsigned numSuccessors) {
900   if (op->getNumSuccessors() < numSuccessors) {
901     return op->emitOpError("requires at least ")
902            << numSuccessors << " successors but found "
903            << op->getNumSuccessors();
904   }
905   return verifyTerminatorSuccessors(op);
906 }
907 
908 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
909   for (auto resultType : op->getResultTypes()) {
910     auto elementType = getTensorOrVectorElementType(resultType);
911     bool isBoolType = elementType.isInteger(1);
912     if (!isBoolType)
913       return op->emitOpError() << "requires a bool result type";
914   }
915 
916   return success();
917 }
918 
919 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) {
920   for (auto resultType : op->getResultTypes())
921     if (!getTensorOrVectorElementType(resultType).isa<FloatType>())
922       return op->emitOpError() << "requires a floating point type";
923 
924   return success();
925 }
926 
927 LogicalResult
928 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) {
929   for (auto resultType : op->getResultTypes())
930     if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex())
931       return op->emitOpError() << "requires an integer or index type";
932   return success();
933 }
934 
935 static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
936                                          bool isOperand) {
937   auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName);
938   if (!sizeAttr)
939     return op->emitOpError("requires 1D vector attribute '") << attrName << "'";
940 
941   auto sizeAttrType = sizeAttr.getType().dyn_cast<VectorType>();
942   if (!sizeAttrType || sizeAttrType.getRank() != 1)
943     return op->emitOpError("requires 1D vector attribute '") << attrName << "'";
944 
945   if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) {
946         return !element.isNonNegative();
947       }))
948     return op->emitOpError("'")
949            << attrName << "' attribute cannot have negative elements";
950 
951   size_t totalCount = std::accumulate(
952       sizeAttr.begin(), sizeAttr.end(), 0,
953       [](unsigned all, APInt one) { return all + one.getZExtValue(); });
954 
955   if (isOperand && totalCount != op->getNumOperands())
956     return op->emitOpError("operand count (")
957            << op->getNumOperands() << ") does not match with the total size ("
958            << totalCount << ") specified in attribute '" << attrName << "'";
959   else if (!isOperand && totalCount != op->getNumResults())
960     return op->emitOpError("result count (")
961            << op->getNumResults() << ") does not match with the total size ("
962            << totalCount << ") specified in attribute '" << attrName << "'";
963   return success();
964 }
965 
966 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op,
967                                                    StringRef attrName) {
968   return verifyValueSizeAttr(op, attrName, /*isOperand=*/true);
969 }
970 
971 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op,
972                                                   StringRef attrName) {
973   return verifyValueSizeAttr(op, attrName, /*isOperand=*/false);
974 }
975 
976 //===----------------------------------------------------------------------===//
977 // BinaryOp implementation
978 //===----------------------------------------------------------------------===//
979 
980 // These functions are out-of-line implementations of the methods in BinaryOp,
981 // which avoids them being template instantiated/duplicated.
982 
983 void impl::buildBinaryOp(Builder *builder, OperationState &result, Value lhs,
984                          Value rhs) {
985   assert(lhs.getType() == rhs.getType());
986   result.addOperands({lhs, rhs});
987   result.types.push_back(lhs.getType());
988 }
989 
990 ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser,
991                                                   OperationState &result) {
992   SmallVector<OpAsmParser::OperandType, 2> ops;
993   Type type;
994   return failure(parser.parseOperandList(ops) ||
995                  parser.parseOptionalAttrDict(result.attributes) ||
996                  parser.parseColonType(type) ||
997                  parser.resolveOperands(ops, type, result.operands) ||
998                  parser.addTypeToList(type, result.types));
999 }
1000 
1001 void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) {
1002   assert(op->getNumResults() == 1 && "op should have one result");
1003 
1004   // If not all the operand and result types are the same, just use the
1005   // generic assembly form to avoid omitting information in printing.
1006   auto resultType = op->getResult(0).getType();
1007   if (llvm::any_of(op->getOperandTypes(),
1008                    [&](Type type) { return type != resultType; })) {
1009     p.printGenericOp(op);
1010     return;
1011   }
1012 
1013   p << op->getName() << ' ';
1014   p.printOperands(op->getOperands());
1015   p.printOptionalAttrDict(op->getAttrs());
1016   // Now we can output only one type for all operands and the result.
1017   p << " : " << resultType;
1018 }
1019 
1020 //===----------------------------------------------------------------------===//
1021 // CastOp implementation
1022 //===----------------------------------------------------------------------===//
1023 
1024 void impl::buildCastOp(Builder *builder, OperationState &result, Value source,
1025                        Type destType) {
1026   result.addOperands(source);
1027   result.addTypes(destType);
1028 }
1029 
1030 ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) {
1031   OpAsmParser::OperandType srcInfo;
1032   Type srcType, dstType;
1033   return failure(parser.parseOperand(srcInfo) ||
1034                  parser.parseOptionalAttrDict(result.attributes) ||
1035                  parser.parseColonType(srcType) ||
1036                  parser.resolveOperand(srcInfo, srcType, result.operands) ||
1037                  parser.parseKeywordType("to", dstType) ||
1038                  parser.addTypeToList(dstType, result.types));
1039 }
1040 
1041 void impl::printCastOp(Operation *op, OpAsmPrinter &p) {
1042   p << op->getName() << ' ' << op->getOperand(0);
1043   p.printOptionalAttrDict(op->getAttrs());
1044   p << " : " << op->getOperand(0).getType() << " to "
1045     << op->getResult(0).getType();
1046 }
1047 
1048 Value impl::foldCastOp(Operation *op) {
1049   // Identity cast
1050   if (op->getOperand(0).getType() == op->getResult(0).getType())
1051     return op->getOperand(0);
1052   return nullptr;
1053 }
1054 
1055 //===----------------------------------------------------------------------===//
1056 // Misc. utils
1057 //===----------------------------------------------------------------------===//
1058 
1059 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the
1060 /// region's only block if it does not have a terminator already. If the region
1061 /// is empty, insert a new block first. `buildTerminatorOp` should return the
1062 /// terminator operation to insert.
1063 void impl::ensureRegionTerminator(
1064     Region &region, Location loc,
1065     function_ref<Operation *()> buildTerminatorOp) {
1066   if (region.empty())
1067     region.push_back(new Block);
1068 
1069   Block &block = region.back();
1070   if (!block.empty() && block.back().isKnownTerminator())
1071     return;
1072 
1073   block.push_back(buildTerminatorOp());
1074 }
1075