1 //===- Operation.cpp - Operation support code -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/IR/Operation.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "llvm/Support/MemAlloc.h"
#include <numeric>
18 
19 using namespace mlir;
20 
21 OpAsmParser::~OpAsmParser() {}
22 
23 //===----------------------------------------------------------------------===//
24 // OperationName
25 //===----------------------------------------------------------------------===//
26 
27 /// Form the OperationName for an op with the specified string.  This either is
28 /// a reference to an AbstractOperation if one is known, or a uniqued Identifier
29 /// if not.
30 OperationName::OperationName(StringRef name, MLIRContext *context) {
31   if (auto *op = AbstractOperation::lookup(name, context))
32     representation = op;
33   else
34     representation = Identifier::get(name, context);
35 }
36 
37 /// Return the name of the dialect this operation is registered to.
38 StringRef OperationName::getDialect() const {
39   return getStringRef().split('.').first;
40 }
41 
42 /// Return the operation name with dialect name stripped, if it has one.
43 StringRef OperationName::stripDialect() const {
44   auto splitName = getStringRef().split(".");
45   return splitName.second.empty() ? splitName.first : splitName.second;
46 }
47 
48 /// Return the name of this operation. This always succeeds.
49 StringRef OperationName::getStringRef() const {
50   return getIdentifier().strref();
51 }
52 
53 /// Return the name of this operation as an identifier. This always succeeds.
54 Identifier OperationName::getIdentifier() const {
55   if (auto *op = representation.dyn_cast<const AbstractOperation *>())
56     return op->name;
57   return representation.get<Identifier>();
58 }
59 
60 const AbstractOperation *OperationName::getAbstractOperation() const {
61   return representation.dyn_cast<const AbstractOperation *>();
62 }
63 
64 OperationName OperationName::getFromOpaquePointer(void *pointer) {
65   return OperationName(RepresentationUnion::getFromOpaqueValue(pointer));
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // Operation
70 //===----------------------------------------------------------------------===//
71 
72 /// Create a new Operation with the specific fields.
73 Operation *Operation::create(Location location, OperationName name,
74                              TypeRange resultTypes, ValueRange operands,
75                              ArrayRef<NamedAttribute> attributes,
76                              BlockRange successors, unsigned numRegions) {
77   return create(location, name, resultTypes, operands,
78                 MutableDictionaryAttr(attributes), successors, numRegions);
79 }
80 
81 /// Create a new Operation from operation state.
82 Operation *Operation::create(const OperationState &state) {
83   return create(state.location, state.name, state.types, state.operands,
84                 state.attributes, state.successors, state.regions);
85 }
86 
87 /// Create a new Operation with the specific fields.
88 Operation *Operation::create(Location location, OperationName name,
89                              TypeRange resultTypes, ValueRange operands,
90                              MutableDictionaryAttr attributes,
91                              BlockRange successors, RegionRange regions) {
92   unsigned numRegions = regions.size();
93   Operation *op = create(location, name, resultTypes, operands, attributes,
94                          successors, numRegions);
95   for (unsigned i = 0; i < numRegions; ++i)
96     if (regions[i])
97       op->getRegion(i).takeBody(*regions[i]);
98   return op;
99 }
100 
101 /// Overload of create that takes an existing MutableDictionaryAttr to avoid
102 /// unnecessarily uniquing a list of attributes.
103 Operation *Operation::create(Location location, OperationName name,
104                              TypeRange resultTypes, ValueRange operands,
105                              MutableDictionaryAttr attributes,
106                              BlockRange successors, unsigned numRegions) {
107   // We only need to allocate additional memory for a subset of results.
108   unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size());
109   unsigned numInlineResults = OpResult::getNumInline(resultTypes.size());
110   unsigned numSuccessors = successors.size();
111   unsigned numOperands = operands.size();
112 
113   // If the operation is known to have no operands, don't allocate an operand
114   // storage.
115   bool needsOperandStorage = true;
116   if (operands.empty()) {
117     if (const AbstractOperation *abstractOp = name.getAbstractOperation())
118       needsOperandStorage = !abstractOp->hasTrait<OpTrait::ZeroOperands>();
119   }
120 
121   // Compute the byte size for the operation and the operand storage.
122   auto byteSize =
123       totalSizeToAlloc<detail::InLineOpResult, detail::TrailingOpResult,
124                        BlockOperand, Region, detail::OperandStorage>(
125           numInlineResults, numTrailingResults, numSuccessors, numRegions,
126           needsOperandStorage ? 1 : 0);
127   byteSize +=
128       llvm::alignTo(detail::OperandStorage::additionalAllocSize(numOperands),
129                     alignof(Operation));
130   void *rawMem = malloc(byteSize);
131 
132   // Create the new Operation.
133   Operation *op =
134       ::new (rawMem) Operation(location, name, resultTypes, numSuccessors,
135                                numRegions, attributes, needsOperandStorage);
136 
137   assert((numSuccessors == 0 || !op->isKnownNonTerminator()) &&
138          "unexpected successors in a non-terminator operation");
139 
140   // Initialize the results.
141   for (unsigned i = 0; i < numInlineResults; ++i)
142     new (op->getInlineResult(i)) detail::InLineOpResult();
143   for (unsigned i = 0; i < numTrailingResults; ++i)
144     new (op->getTrailingResult(i)) detail::TrailingOpResult(i);
145 
146   // Initialize the regions.
147   for (unsigned i = 0; i != numRegions; ++i)
148     new (&op->getRegion(i)) Region(op);
149 
150   // Initialize the operands.
151   if (needsOperandStorage)
152     new (&op->getOperandStorage()) detail::OperandStorage(op, operands);
153 
154   // Initialize the successors.
155   auto blockOperands = op->getBlockOperands();
156   for (unsigned i = 0; i != numSuccessors; ++i)
157     new (&blockOperands[i]) BlockOperand(op, successors[i]);
158 
159   return op;
160 }
161 
162 Operation::Operation(Location location, OperationName name,
163                      TypeRange resultTypes, unsigned numSuccessors,
164                      unsigned numRegions,
165                      const MutableDictionaryAttr &attributes,
166                      bool hasOperandStorage)
167     : location(location), numSuccs(numSuccessors), numRegions(numRegions),
168       hasOperandStorage(hasOperandStorage), hasSingleResult(false), name(name),
169       attrs(attributes) {
170   if (!resultTypes.empty()) {
171     // If there is a single result it is stored in-place, otherwise use a tuple.
172     hasSingleResult = resultTypes.size() == 1;
173     if (hasSingleResult)
174       resultType = resultTypes.front();
175     else
176       resultType = TupleType::get(resultTypes, location->getContext());
177   }
178 }
179 
180 // Operations are deleted through the destroy() member because they are
181 // allocated via malloc.
182 Operation::~Operation() {
183   assert(block == nullptr && "operation destroyed but still in a block");
184 
185   // Explicitly run the destructors for the operands.
186   if (hasOperandStorage)
187     getOperandStorage().~OperandStorage();
188 
189   // Explicitly run the destructors for the successors.
190   for (auto &successor : getBlockOperands())
191     successor.~BlockOperand();
192 
193   // Explicitly destroy the regions.
194   for (auto &region : getRegions())
195     region.~Region();
196 }
197 
198 /// Destroy this operation or one of its subclasses.
199 void Operation::destroy() {
200   this->~Operation();
201   free(this);
202 }
203 
204 /// Return the context this operation is associated with.
205 MLIRContext *Operation::getContext() { return location->getContext(); }
206 
207 /// Return the dialect this operation is associated with, or nullptr if the
208 /// associated dialect is not registered.
209 Dialect *Operation::getDialect() {
210   if (auto *abstractOp = getAbstractOperation())
211     return &abstractOp->dialect;
212 
213   // If this operation hasn't been registered or doesn't have abstract
214   // operation, try looking up the dialect name in the context.
215   return getContext()->getLoadedDialect(getName().getDialect());
216 }
217 
218 Region *Operation::getParentRegion() {
219   return block ? block->getParent() : nullptr;
220 }
221 
222 Operation *Operation::getParentOp() {
223   return block ? block->getParentOp() : nullptr;
224 }
225 
226 /// Return true if this operation is a proper ancestor of the `other`
227 /// operation.
228 bool Operation::isProperAncestor(Operation *other) {
229   while ((other = other->getParentOp()))
230     if (this == other)
231       return true;
232   return false;
233 }
234 
235 /// Replace any uses of 'from' with 'to' within this operation.
236 void Operation::replaceUsesOfWith(Value from, Value to) {
237   if (from == to)
238     return;
239   for (auto &operand : getOpOperands())
240     if (operand.get() == from)
241       operand.set(to);
242 }
243 
244 /// Replace the current operands of this operation with the ones provided in
245 /// 'operands'.
246 void Operation::setOperands(ValueRange operands) {
247   if (LLVM_LIKELY(hasOperandStorage))
248     return getOperandStorage().setOperands(this, operands);
249   assert(operands.empty() && "setting operands without an operand storage");
250 }
251 
252 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
253 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
254 /// than the range pointed to by 'start'+'length'.
255 void Operation::setOperands(unsigned start, unsigned length,
256                             ValueRange operands) {
257   assert((start + length) <= getNumOperands() &&
258          "invalid operand range specified");
259   if (LLVM_LIKELY(hasOperandStorage))
260     return getOperandStorage().setOperands(this, start, length, operands);
261   assert(operands.empty() && "setting operands without an operand storage");
262 }
263 
264 /// Insert the given operands into the operand list at the given 'index'.
265 void Operation::insertOperands(unsigned index, ValueRange operands) {
266   if (LLVM_LIKELY(hasOperandStorage))
267     return setOperands(index, /*length=*/0, operands);
268   assert(operands.empty() && "inserting operands without an operand storage");
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // Diagnostics
273 //===----------------------------------------------------------------------===//
274 
275 /// Emit an error about fatal conditions with this operation, reporting up to
276 /// any diagnostic handlers that may be listening.
277 InFlightDiagnostic Operation::emitError(const Twine &message) {
278   InFlightDiagnostic diag = mlir::emitError(getLoc(), message);
279   if (getContext()->shouldPrintOpOnDiagnostic()) {
280     // Print out the operation explicitly here so that we can print the generic
281     // form.
282     // TODO: It would be nice if we could instead provide the
283     // specific printing flags when adding the operation as an argument to the
284     // diagnostic.
285     std::string printedOp;
286     {
287       llvm::raw_string_ostream os(printedOp);
288       print(os, OpPrintingFlags().printGenericOpForm().useLocalScope());
289     }
290     diag.attachNote(getLoc()) << "see current operation: " << printedOp;
291   }
292   return diag;
293 }
294 
295 /// Emit a warning about this operation, reporting up to any diagnostic
296 /// handlers that may be listening.
297 InFlightDiagnostic Operation::emitWarning(const Twine &message) {
298   InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message);
299   if (getContext()->shouldPrintOpOnDiagnostic())
300     diag.attachNote(getLoc()) << "see current operation: " << *this;
301   return diag;
302 }
303 
304 /// Emit a remark about this operation, reporting up to any diagnostic
305 /// handlers that may be listening.
306 InFlightDiagnostic Operation::emitRemark(const Twine &message) {
307   InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message);
308   if (getContext()->shouldPrintOpOnDiagnostic())
309     diag.attachNote(getLoc()) << "see current operation: " << *this;
310   return diag;
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // Operation Ordering
315 //===----------------------------------------------------------------------===//
316 
// Namespace-scope definitions of the static constexpr order constants declared
// in the Operation class. Before C++17 (inline variables), ODR-using such a
// member requires a definition like these.
// (The constants are read below, e.g. `orderIndex = prevNode->orderIndex +
// kOrderStride` in updateOrderIfNecessary.)
constexpr unsigned Operation::kInvalidOrderIdx;
constexpr unsigned Operation::kOrderStride;
319 
320 /// Given an operation 'other' that is within the same parent block, return
321 /// whether the current operation is before 'other' in the operation list
322 /// of the parent block.
323 /// Note: This function has an average complexity of O(1), but worst case may
324 /// take O(N) where N is the number of operations within the parent block.
325 bool Operation::isBeforeInBlock(Operation *other) {
326   assert(block && "Operations without parent blocks have no order.");
327   assert(other && other->block == block &&
328          "Expected other operation to have the same parent block.");
329   // If the order of the block is already invalid, directly recompute the
330   // parent.
331   if (!block->isOpOrderValid()) {
332     block->recomputeOpOrder();
333   } else {
334     // Update the order either operation if necessary.
335     updateOrderIfNecessary();
336     other->updateOrderIfNecessary();
337   }
338 
339   return orderIndex < other->orderIndex;
340 }
341 
342 /// Update the order index of this operation of this operation if necessary,
343 /// potentially recomputing the order of the parent block.
344 void Operation::updateOrderIfNecessary() {
345   assert(block && "expected valid parent");
346 
347   // If the order is valid for this operation there is nothing to do.
348   if (hasValidOrder())
349     return;
350   Operation *blockFront = &block->front();
351   Operation *blockBack = &block->back();
352 
353   // This method is expected to only be invoked on blocks with more than one
354   // operation.
355   assert(blockFront != blockBack && "expected more than one operation");
356 
357   // If the operation is at the end of the block.
358   if (this == blockBack) {
359     Operation *prevNode = getPrevNode();
360     if (!prevNode->hasValidOrder())
361       return block->recomputeOpOrder();
362 
363     // Add the stride to the previous operation.
364     orderIndex = prevNode->orderIndex + kOrderStride;
365     return;
366   }
367 
368   // If this is the first operation try to use the next operation to compute the
369   // ordering.
370   if (this == blockFront) {
371     Operation *nextNode = getNextNode();
372     if (!nextNode->hasValidOrder())
373       return block->recomputeOpOrder();
374     // There is no order to give this operation.
375     if (nextNode->orderIndex == 0)
376       return block->recomputeOpOrder();
377 
378     // If we can't use the stride, just take the middle value left. This is safe
379     // because we know there is at least one valid index to assign to.
380     if (nextNode->orderIndex <= kOrderStride)
381       orderIndex = (nextNode->orderIndex / 2);
382     else
383       orderIndex = kOrderStride;
384     return;
385   }
386 
387   // Otherwise, this operation is between two others. Place this operation in
388   // the middle of the previous and next if possible.
389   Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
390   if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
391     return block->recomputeOpOrder();
392   unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;
393 
394   // Check to see if there is a valid order between the two.
395   if (prevOrder + 1 == nextOrder)
396     return block->recomputeOpOrder();
397   orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
398 }
399 
400 //===----------------------------------------------------------------------===//
401 // ilist_traits for Operation
402 //===----------------------------------------------------------------------===//
403 
// Out-of-line definitions of the ilist node-access hooks for the node options
// computed for ::mlir::Operation. Each of the four functions below forwards to
// the generic NodeAccess helper instantiated with the same OptionsT. They
// convert between an Operation pointer and its list-node pointer, in both
// directions, for mutable and const pointers.
// NOTE(review): presumably these are declared (as a specialization) in
// Operation.h and defined here to keep the header light -- confirm there.

/// Operation* -> list node*.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(N);
}

/// const Operation* -> const list node*.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer N)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(N);
}

/// list node* -> Operation*.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(N);
}

/// const list node* -> const Operation*.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *N)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(N);
}
429 
430 void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
431   op->destroy();
432 }
433 
/// Recover the Block that owns this operation list.
/// The traits object is a base of the iplist<Operation> (see the static_cast),
/// and that list is assumed to be a member of Block located through
/// Block::getSublistAccess. The owning Block is therefore found by subtracting
/// the member's byte offset from the list's address.
/// NOTE(review): the offset is computed by applying the member pointer to a
/// null Block pointer -- an offsetof-style idiom that is technically undefined
/// behavior in standard C++.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset);
}
439 
440 /// This is a trait method invoked when an operation is added to a block.  We
441 /// keep the block pointer up to date.
442 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
443   assert(!op->getBlock() && "already in an operation block!");
444   op->block = getContainingBlock();
445 
446   // Invalidate the order on the operation.
447   op->orderIndex = Operation::kInvalidOrderIdx;
448 }
449 
450 /// This is a trait method invoked when an operation is removed from a block.
451 /// We keep the block pointer up to date.
452 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) {
453   assert(op->block && "not already in an operation block!");
454   op->block = nullptr;
455 }
456 
457 /// This is a trait method invoked when an operation is moved from one block
458 /// to another.  We keep the block pointer up to date.
459 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList(
460     ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) {
461   Block *curParent = getContainingBlock();
462 
463   // Invalidate the ordering of the parent block.
464   curParent->invalidateOpOrder();
465 
466   // If we are transferring operations within the same block, the block
467   // pointer doesn't need to be updated.
468   if (curParent == otherList.getContainingBlock())
469     return;
470 
471   // Update the 'block' member of each operation.
472   for (; first != last; ++first)
473     first->block = curParent;
474 }
475 
476 /// Remove this operation (and its descendants) from its Block and delete
477 /// all of them.
478 void Operation::erase() {
479   if (auto *parent = getBlock())
480     parent->getOperations().erase(this);
481   else
482     destroy();
483 }
484 
485 /// Unlink this operation from its current block and insert it right before
486 /// `existingOp` which may be in the same or another block in the same
487 /// function.
488 void Operation::moveBefore(Operation *existingOp) {
489   moveBefore(existingOp->getBlock(), existingOp->getIterator());
490 }
491 
492 /// Unlink this operation from its current basic block and insert it right
493 /// before `iterator` in the specified basic block.
494 void Operation::moveBefore(Block *block,
495                            llvm::iplist<Operation>::iterator iterator) {
496   block->getOperations().splice(iterator, getBlock()->getOperations(),
497                                 getIterator());
498 }
499 
500 /// Unlink this operation from its current block and insert it right after
501 /// `existingOp` which may be in the same or another block in the same function.
502 void Operation::moveAfter(Operation *existingOp) {
503   moveAfter(existingOp->getBlock(), existingOp->getIterator());
504 }
505 
506 /// Unlink this operation from its current block and insert it right after
507 /// `iterator` in the specified block.
508 void Operation::moveAfter(Block *block,
509                           llvm::iplist<Operation>::iterator iterator) {
510   assert(iterator != block->end() && "cannot move after end of block");
511   moveBefore(&*std::next(iterator));
512 }
513 
514 /// This drops all operand uses from this operation, which is an essential
515 /// step in breaking cyclic dependences between references when they are to
516 /// be deleted.
517 void Operation::dropAllReferences() {
518   for (auto &op : getOpOperands())
519     op.drop();
520 
521   for (auto &region : getRegions())
522     region.dropAllReferences();
523 
524   for (auto &dest : getBlockOperands())
525     dest.drop();
526 }
527 
528 /// This drops all uses of any values defined by this operation or its nested
529 /// regions, wherever they are located.
530 void Operation::dropAllDefinedValueUses() {
531   dropAllUses();
532 
533   for (auto &region : getRegions())
534     for (auto &block : region)
535       block.dropAllDefinedValueUses();
536 }
537 
538 /// Return the number of results held by this operation.
539 unsigned Operation::getNumResults() {
540   if (!resultType)
541     return 0;
542   return hasSingleResult ? 1 : resultType.cast<TupleType>().size();
543 }
544 
545 auto Operation::getResultTypes() -> result_type_range {
546   if (!resultType)
547     return llvm::None;
548   if (hasSingleResult)
549     return resultType;
550   return resultType.cast<TupleType>().getTypes();
551 }
552 
553 void Operation::setSuccessor(Block *block, unsigned index) {
554   assert(index < getNumSuccessors());
555   getBlockOperands()[index].set(block);
556 }
557 
558 /// Attempt to fold this operation using the Op's registered foldHook.
559 LogicalResult Operation::fold(ArrayRef<Attribute> operands,
560                               SmallVectorImpl<OpFoldResult> &results) {
561   // If we have a registered operation definition matching this one, use it to
562   // try to constant fold the operation.
563   auto *abstractOp = getAbstractOperation();
564   if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results)))
565     return success();
566 
567   // Otherwise, fall back on the dialect hook to handle it.
568   Dialect *dialect = getDialect();
569   if (!dialect)
570     return failure();
571 
572   auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>();
573   if (!interface)
574     return failure();
575 
576   return interface->fold(this, operands, results);
577 }
578 
579 /// Emit an error with the op name prefixed, like "'dim' op " which is
580 /// convenient for verifiers.
581 InFlightDiagnostic Operation::emitOpError(const Twine &message) {
582   return emitError() << "'" << getName() << "' op " << message;
583 }
584 
585 //===----------------------------------------------------------------------===//
586 // Operation Cloning
587 //===----------------------------------------------------------------------===//
588 
589 /// Create a deep copy of this operation but keep the operation regions empty.
590 /// Operands are remapped using `mapper` (if present), and `mapper` is updated
591 /// to contain the results.
592 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) {
593   SmallVector<Value, 8> operands;
594   SmallVector<Block *, 2> successors;
595 
596   // Remap the operands.
597   operands.reserve(getNumOperands());
598   for (auto opValue : getOperands())
599     operands.push_back(mapper.lookupOrDefault(opValue));
600 
601   // Remap the successors.
602   successors.reserve(getNumSuccessors());
603   for (Block *successor : getSuccessors())
604     successors.push_back(mapper.lookupOrDefault(successor));
605 
606   // Create the new operation.
607   auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs,
608                        successors, getNumRegions());
609 
610   // Remember the mapping of any results.
611   for (unsigned i = 0, e = getNumResults(); i != e; ++i)
612     mapper.map(getResult(i), newOp->getResult(i));
613 
614   return newOp;
615 }
616 
617 Operation *Operation::cloneWithoutRegions() {
618   BlockAndValueMapping mapper;
619   return cloneWithoutRegions(mapper);
620 }
621 
622 /// Create a deep copy of this operation, remapping any operands that use
623 /// values outside of the operation using the map that is provided (leaving
624 /// them alone if no entry is present).  Replaces references to cloned
625 /// sub-operations to the corresponding operation that is copied, and adds
626 /// those mappings to the map.
627 Operation *Operation::clone(BlockAndValueMapping &mapper) {
628   auto *newOp = cloneWithoutRegions(mapper);
629 
630   // Clone the regions.
631   for (unsigned i = 0; i != numRegions; ++i)
632     getRegion(i).cloneInto(&newOp->getRegion(i), mapper);
633 
634   return newOp;
635 }
636 
637 Operation *Operation::clone() {
638   BlockAndValueMapping mapper;
639   return clone(mapper);
640 }
641 
642 //===----------------------------------------------------------------------===//
643 // OpState trait class.
644 //===----------------------------------------------------------------------===//
645 
646 // The fallback for the parser is to reject the custom assembly form.
647 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) {
648   return parser.emitError(parser.getNameLoc(), "has no custom assembly form");
649 }
650 
651 // The fallback for the printer is to print in the generic assembly form.
652 void OpState::print(OpAsmPrinter &p) { p.printGenericOp(getOperation()); }
653 
654 /// Emit an error about fatal conditions with this operation, reporting up to
655 /// any diagnostic handlers that may be listening.
656 InFlightDiagnostic OpState::emitError(const Twine &message) {
657   return getOperation()->emitError(message);
658 }
659 
660 /// Emit an error with the op name prefixed, like "'dim' op " which is
661 /// convenient for verifiers.
662 InFlightDiagnostic OpState::emitOpError(const Twine &message) {
663   return getOperation()->emitOpError(message);
664 }
665 
666 /// Emit a warning about this operation, reporting up to any diagnostic
667 /// handlers that may be listening.
668 InFlightDiagnostic OpState::emitWarning(const Twine &message) {
669   return getOperation()->emitWarning(message);
670 }
671 
672 /// Emit a remark about this operation, reporting up to any diagnostic
673 /// handlers that may be listening.
674 InFlightDiagnostic OpState::emitRemark(const Twine &message) {
675   return getOperation()->emitRemark(message);
676 }
677 
678 //===----------------------------------------------------------------------===//
679 // Op Trait implementations
680 //===----------------------------------------------------------------------===//
681 
682 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
683   auto *argumentOp = op->getOperand(0).getDefiningOp();
684   if (argumentOp && op->getName() == argumentOp->getName()) {
685     // Replace the outer operation output with the inner operation.
686     return op->getOperand(0);
687   }
688 
689   return {};
690 }
691 
692 OpFoldResult OpTrait::impl::foldInvolution(Operation *op) {
693   auto *argumentOp = op->getOperand(0).getDefiningOp();
694   if (argumentOp && op->getName() == argumentOp->getName()) {
695     // Replace the outer involutions output with inner's input.
696     return argumentOp->getOperand(0);
697   }
698 
699   return {};
700 }
701 
702 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) {
703   if (op->getNumOperands() != 0)
704     return op->emitOpError() << "requires zero operands";
705   return success();
706 }
707 
708 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) {
709   if (op->getNumOperands() != 1)
710     return op->emitOpError() << "requires a single operand";
711   return success();
712 }
713 
714 LogicalResult OpTrait::impl::verifyNOperands(Operation *op,
715                                              unsigned numOperands) {
716   if (op->getNumOperands() != numOperands) {
717     return op->emitOpError() << "expected " << numOperands
718                              << " operands, but found " << op->getNumOperands();
719   }
720   return success();
721 }
722 
723 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op,
724                                                     unsigned numOperands) {
725   if (op->getNumOperands() < numOperands)
726     return op->emitOpError()
727            << "expected " << numOperands << " or more operands";
728   return success();
729 }
730 
731 /// If this is a vector type, or a tensor type, return the scalar element type
732 /// that it is built around, otherwise return the type unmodified.
733 static Type getTensorOrVectorElementType(Type type) {
734   if (auto vec = type.dyn_cast<VectorType>())
735     return vec.getElementType();
736 
737   // Look through tensor<vector<...>> to find the underlying element type.
738   if (auto tensor = type.dyn_cast<TensorType>())
739     return getTensorOrVectorElementType(tensor.getElementType());
740   return type;
741 }
742 
743 LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) {
744   // FIXME: Add back check for no side effects on operation.
745   // Currently adding it would cause the shared library build
746   // to fail since there would be a dependency of IR on SideEffectInterfaces
747   // which is cyclical.
748   return success();
749 }
750 
751 LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) {
752   // FIXME: Add back check for no side effects on operation.
753   // Currently adding it would cause the shared library build
754   // to fail since there would be a dependency of IR on SideEffectInterfaces
755   // which is cyclical.
756   return success();
757 }
758 
759 LogicalResult
760 OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) {
761   for (auto opType : op->getOperandTypes()) {
762     auto type = getTensorOrVectorElementType(opType);
763     if (!type.isSignlessIntOrIndex())
764       return op->emitOpError() << "requires an integer or index type";
765   }
766   return success();
767 }
768 
769 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) {
770   for (auto opType : op->getOperandTypes()) {
771     auto type = getTensorOrVectorElementType(opType);
772     if (!type.isa<FloatType>())
773       return op->emitOpError("requires a float type");
774   }
775   return success();
776 }
777 
778 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) {
779   // Zero or one operand always have the "same" type.
780   unsigned nOperands = op->getNumOperands();
781   if (nOperands < 2)
782     return success();
783 
784   auto type = op->getOperand(0).getType();
785   for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
786     if (opType != type)
787       return op->emitOpError() << "requires all operands to have the same type";
788   return success();
789 }
790 
791 LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) {
792   if (op->getNumRegions() != 0)
793     return op->emitOpError() << "requires zero regions";
794   return success();
795 }
796 
797 LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) {
798   if (op->getNumRegions() != 1)
799     return op->emitOpError() << "requires one region";
800   return success();
801 }
802 
803 LogicalResult OpTrait::impl::verifyNRegions(Operation *op,
804                                             unsigned numRegions) {
805   if (op->getNumRegions() != numRegions)
806     return op->emitOpError() << "expected " << numRegions << " regions";
807   return success();
808 }
809 
810 LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op,
811                                                    unsigned numRegions) {
812   if (op->getNumRegions() < numRegions)
813     return op->emitOpError() << "expected " << numRegions << " or more regions";
814   return success();
815 }
816 
817 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) {
818   if (op->getNumResults() != 0)
819     return op->emitOpError() << "requires zero results";
820   return success();
821 }
822 
823 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) {
824   if (op->getNumResults() != 1)
825     return op->emitOpError() << "requires one result";
826   return success();
827 }
828 
829 LogicalResult OpTrait::impl::verifyNResults(Operation *op,
830                                             unsigned numOperands) {
831   if (op->getNumResults() != numOperands)
832     return op->emitOpError() << "expected " << numOperands << " results";
833   return success();
834 }
835 
836 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op,
837                                                    unsigned numOperands) {
838   if (op->getNumResults() < numOperands)
839     return op->emitOpError()
840            << "expected " << numOperands << " or more results";
841   return success();
842 }
843 
844 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
845   if (failed(verifyAtLeastNOperands(op, 1)))
846     return failure();
847 
848   auto type = op->getOperand(0).getType();
849   for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
850     if (failed(verifyCompatibleShape(opType, type)))
851       return op->emitOpError() << "requires the same shape for all operands";
852   }
853   return success();
854 }
855 
856 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
857   if (failed(verifyAtLeastNOperands(op, 1)) ||
858       failed(verifyAtLeastNResults(op, 1)))
859     return failure();
860 
861   auto type = op->getOperand(0).getType();
862   for (auto resultType : op->getResultTypes()) {
863     if (failed(verifyCompatibleShape(resultType, type)))
864       return op->emitOpError()
865              << "requires the same shape for all operands and results";
866   }
867   for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
868     if (failed(verifyCompatibleShape(opType, type)))
869       return op->emitOpError()
870              << "requires the same shape for all operands and results";
871   }
872   return success();
873 }
874 
875 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) {
876   if (failed(verifyAtLeastNOperands(op, 1)))
877     return failure();
878   auto elementType = getElementTypeOrSelf(op->getOperand(0));
879 
880   for (auto operand : llvm::drop_begin(op->getOperands(), 1)) {
881     if (getElementTypeOrSelf(operand) != elementType)
882       return op->emitOpError("requires the same element type for all operands");
883   }
884 
885   return success();
886 }
887 
888 LogicalResult
889 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
890   if (failed(verifyAtLeastNOperands(op, 1)) ||
891       failed(verifyAtLeastNResults(op, 1)))
892     return failure();
893 
894   auto elementType = getElementTypeOrSelf(op->getResult(0));
895 
896   // Verify result element type matches first result's element type.
897   for (auto result : llvm::drop_begin(op->getResults(), 1)) {
898     if (getElementTypeOrSelf(result) != elementType)
899       return op->emitOpError(
900           "requires the same element type for all operands and results");
901   }
902 
903   // Verify operand's element type matches first result's element type.
904   for (auto operand : op->getOperands()) {
905     if (getElementTypeOrSelf(operand) != elementType)
906       return op->emitOpError(
907           "requires the same element type for all operands and results");
908   }
909 
910   return success();
911 }
912 
913 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
914   if (failed(verifyAtLeastNOperands(op, 1)) ||
915       failed(verifyAtLeastNResults(op, 1)))
916     return failure();
917 
918   auto type = op->getResult(0).getType();
919   auto elementType = getElementTypeOrSelf(type);
920   for (auto resultType : op->getResultTypes().drop_front(1)) {
921     if (getElementTypeOrSelf(resultType) != elementType ||
922         failed(verifyCompatibleShape(resultType, type)))
923       return op->emitOpError()
924              << "requires the same type for all operands and results";
925   }
926   for (auto opType : op->getOperandTypes()) {
927     if (getElementTypeOrSelf(opType) != elementType ||
928         failed(verifyCompatibleShape(opType, type)))
929       return op->emitOpError()
930              << "requires the same type for all operands and results";
931   }
932   return success();
933 }
934 
935 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
936   Block *block = op->getBlock();
937   // Verify that the operation is at the end of the respective parent block.
938   if (!block || &block->back() != op)
939     return op->emitOpError("must be the last operation in the parent block");
940   return success();
941 }
942 
943 static LogicalResult verifyTerminatorSuccessors(Operation *op) {
944   auto *parent = op->getParentRegion();
945 
946   // Verify that the operands lines up with the BB arguments in the successor.
947   for (Block *succ : op->getSuccessors())
948     if (succ->getParent() != parent)
949       return op->emitError("reference to block defined in another region");
950   return success();
951 }
952 
953 LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) {
954   if (op->getNumSuccessors() != 0) {
955     return op->emitOpError("requires 0 successors but found ")
956            << op->getNumSuccessors();
957   }
958   return success();
959 }
960 
961 LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) {
962   if (op->getNumSuccessors() != 1) {
963     return op->emitOpError("requires 1 successor but found ")
964            << op->getNumSuccessors();
965   }
966   return verifyTerminatorSuccessors(op);
967 }
968 LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op,
969                                                unsigned numSuccessors) {
970   if (op->getNumSuccessors() != numSuccessors) {
971     return op->emitOpError("requires ")
972            << numSuccessors << " successors but found "
973            << op->getNumSuccessors();
974   }
975   return verifyTerminatorSuccessors(op);
976 }
977 LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op,
978                                                       unsigned numSuccessors) {
979   if (op->getNumSuccessors() < numSuccessors) {
980     return op->emitOpError("requires at least ")
981            << numSuccessors << " successors but found "
982            << op->getNumSuccessors();
983   }
984   return verifyTerminatorSuccessors(op);
985 }
986 
987 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
988   for (auto resultType : op->getResultTypes()) {
989     auto elementType = getTensorOrVectorElementType(resultType);
990     bool isBoolType = elementType.isInteger(1);
991     if (!isBoolType)
992       return op->emitOpError() << "requires a bool result type";
993   }
994 
995   return success();
996 }
997 
998 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) {
999   for (auto resultType : op->getResultTypes())
1000     if (!getTensorOrVectorElementType(resultType).isa<FloatType>())
1001       return op->emitOpError() << "requires a floating point type";
1002 
1003   return success();
1004 }
1005 
1006 LogicalResult
1007 OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) {
1008   for (auto resultType : op->getResultTypes())
1009     if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex())
1010       return op->emitOpError() << "requires an integer or index type";
1011   return success();
1012 }
1013 
1014 static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
1015                                          bool isOperand) {
1016   auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName);
1017   if (!sizeAttr)
1018     return op->emitOpError("requires 1D vector attribute '") << attrName << "'";
1019 
1020   auto sizeAttrType = sizeAttr.getType().dyn_cast<VectorType>();
1021   if (!sizeAttrType || sizeAttrType.getRank() != 1)
1022     return op->emitOpError("requires 1D vector attribute '") << attrName << "'";
1023 
1024   if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) {
1025         return !element.isNonNegative();
1026       }))
1027     return op->emitOpError("'")
1028            << attrName << "' attribute cannot have negative elements";
1029 
1030   size_t totalCount = std::accumulate(
1031       sizeAttr.begin(), sizeAttr.end(), 0,
1032       [](unsigned all, APInt one) { return all + one.getZExtValue(); });
1033 
1034   if (isOperand && totalCount != op->getNumOperands())
1035     return op->emitOpError("operand count (")
1036            << op->getNumOperands() << ") does not match with the total size ("
1037            << totalCount << ") specified in attribute '" << attrName << "'";
1038   else if (!isOperand && totalCount != op->getNumResults())
1039     return op->emitOpError("result count (")
1040            << op->getNumResults() << ") does not match with the total size ("
1041            << totalCount << ") specified in attribute '" << attrName << "'";
1042   return success();
1043 }
1044 
1045 LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op,
1046                                                    StringRef attrName) {
1047   return verifyValueSizeAttr(op, attrName, /*isOperand=*/true);
1048 }
1049 
1050 LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op,
1051                                                   StringRef attrName) {
1052   return verifyValueSizeAttr(op, attrName, /*isOperand=*/false);
1053 }
1054 
1055 LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
1056   for (Region &region : op->getRegions()) {
1057     if (region.empty())
1058       continue;
1059 
1060     if (region.getNumArguments() != 0) {
1061       if (op->getNumRegions() > 1)
1062         return op->emitOpError("region #")
1063                << region.getRegionNumber() << " should have no arguments";
1064       else
1065         return op->emitOpError("region should have no arguments");
1066     }
1067   }
1068   return success();
1069 }
1070 
1071 //===----------------------------------------------------------------------===//
1072 // BinaryOp implementation
1073 //===----------------------------------------------------------------------===//
1074 
1075 // These functions are out-of-line implementations of the methods in BinaryOp,
1076 // which avoids them being template instantiated/duplicated.
1077 
1078 void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
1079                          Value rhs) {
1080   assert(lhs.getType() == rhs.getType());
1081   result.addOperands({lhs, rhs});
1082   result.types.push_back(lhs.getType());
1083 }
1084 
1085 ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser,
1086                                                   OperationState &result) {
1087   SmallVector<OpAsmParser::OperandType, 2> ops;
1088   Type type;
1089   return failure(parser.parseOperandList(ops) ||
1090                  parser.parseOptionalAttrDict(result.attributes) ||
1091                  parser.parseColonType(type) ||
1092                  parser.resolveOperands(ops, type, result.operands) ||
1093                  parser.addTypeToList(type, result.types));
1094 }
1095 
1096 void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) {
1097   assert(op->getNumResults() == 1 && "op should have one result");
1098 
1099   // If not all the operand and result types are the same, just use the
1100   // generic assembly form to avoid omitting information in printing.
1101   auto resultType = op->getResult(0).getType();
1102   if (llvm::any_of(op->getOperandTypes(),
1103                    [&](Type type) { return type != resultType; })) {
1104     p.printGenericOp(op);
1105     return;
1106   }
1107 
1108   p << op->getName() << ' ';
1109   p.printOperands(op->getOperands());
1110   p.printOptionalAttrDict(op->getAttrs());
1111   // Now we can output only one type for all operands and the result.
1112   p << " : " << resultType;
1113 }
1114 
1115 //===----------------------------------------------------------------------===//
1116 // CastOp implementation
1117 //===----------------------------------------------------------------------===//
1118 
1119 void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source,
1120                        Type destType) {
1121   result.addOperands(source);
1122   result.addTypes(destType);
1123 }
1124 
1125 ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) {
1126   OpAsmParser::OperandType srcInfo;
1127   Type srcType, dstType;
1128   return failure(parser.parseOperand(srcInfo) ||
1129                  parser.parseOptionalAttrDict(result.attributes) ||
1130                  parser.parseColonType(srcType) ||
1131                  parser.resolveOperand(srcInfo, srcType, result.operands) ||
1132                  parser.parseKeywordType("to", dstType) ||
1133                  parser.addTypeToList(dstType, result.types));
1134 }
1135 
1136 void impl::printCastOp(Operation *op, OpAsmPrinter &p) {
1137   p << op->getName() << ' ' << op->getOperand(0);
1138   p.printOptionalAttrDict(op->getAttrs());
1139   p << " : " << op->getOperand(0).getType() << " to "
1140     << op->getResult(0).getType();
1141 }
1142 
1143 Value impl::foldCastOp(Operation *op) {
1144   // Identity cast
1145   if (op->getOperand(0).getType() == op->getResult(0).getType())
1146     return op->getOperand(0);
1147   return nullptr;
1148 }
1149 
1150 //===----------------------------------------------------------------------===//
1151 // Misc. utils
1152 //===----------------------------------------------------------------------===//
1153 
1154 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the
1155 /// region's only block if it does not have a terminator already. If the region
1156 /// is empty, insert a new block first. `buildTerminatorOp` should return the
1157 /// terminator operation to insert.
1158 void impl::ensureRegionTerminator(
1159     Region &region, OpBuilder &builder, Location loc,
1160     function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) {
1161   OpBuilder::InsertionGuard guard(builder);
1162   if (region.empty())
1163     builder.createBlock(&region);
1164 
1165   Block &block = region.back();
1166   if (!block.empty() && block.back().isKnownTerminator())
1167     return;
1168 
1169   builder.setInsertionPointToEnd(&block);
1170   builder.insert(buildTerminatorOp(builder, loc));
1171 }
1172 
1173 /// Create a simple OpBuilder and forward to the OpBuilder version of this
1174 /// function.
1175 void impl::ensureRegionTerminator(
1176     Region &region, Builder &builder, Location loc,
1177     function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) {
1178   OpBuilder opBuilder(builder.getContext());
1179   ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp);
1180 }
1181 
1182 //===----------------------------------------------------------------------===//
1183 // UseIterator
1184 //===----------------------------------------------------------------------===//
1185 
1186 Operation::UseIterator::UseIterator(Operation *op, bool end)
1187     : op(op), res(end ? op->result_end() : op->result_begin()) {
1188   // Only initialize current use if there are results/can be uses.
1189   if (op->getNumResults())
1190     skipOverResultsWithNoUsers();
1191 }
1192 
/// Advance to the next use among all results of the operation. The remaining
/// uses of the current result are visited first. Once they run out, the
/// iterator moves to the next result that has at least one use, via
/// skipOverResultsWithNoUsers.
/// NOTE(review): `*res` is dereferenced unconditionally, so this must not be
/// called on an iterator whose `res` is already at result_end().
Operation::UseIterator &Operation::UseIterator::operator++() {
  // Step to the next use of the current result, unless we are already past
  // its last use.
  if (use != (*res).use_end())
    ++use;
  // The current result has no more uses: move `res` forward and reposition
  // `use` on the next result with users (or on the end sentinel).
  if (use == (*res).use_end()) {
    ++res;
    skipOverResultsWithNoUsers();
  }
  return *this;
}
1204 
1205 void Operation::UseIterator::skipOverResultsWithNoUsers() {
1206   while (res != op->result_end() && (*res).use_empty())
1207     ++res;
1208 
1209   // If we are at the last result, then set use to first use of
1210   // first result (sentinel value used for end).
1211   if (res == op->result_end())
1212     use = {};
1213   else
1214     use = (*res).use_begin();
1215 }
1216