1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ControlFlowInterfaces.h"
10 #include "mlir/IR/StandardTypes.h"
11 
12 using namespace mlir;
13 
14 //===----------------------------------------------------------------------===//
15 // ControlFlowInterfaces
16 //===----------------------------------------------------------------------===//
17 
18 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
19 
20 //===----------------------------------------------------------------------===//
21 // BranchOpInterface
22 //===----------------------------------------------------------------------===//
23 
24 /// Erase an operand from a branch operation that is used as a successor
25 /// operand. 'operandIndex' is the operand within 'operands' to be erased.
26 void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands,
27                                                unsigned operandIndex,
28                                                Operation *op) {
29   assert(operandIndex < operands.size() &&
30          "invalid index for successor operands");
31 
32   // Erase the operand from the operation.
33   size_t fullOperandIndex = operands.getBeginOperandIndex() + operandIndex;
34   op->eraseOperand(fullOperandIndex);
35 
36   // If this operation has an OperandSegmentSizeAttr, keep it up to date.
37   auto operandSegmentAttr =
38       op->getAttrOfType<DenseElementsAttr>("operand_segment_sizes");
39   if (!operandSegmentAttr)
40     return;
41 
42   // Find the segment containing the full operand index and decrement it.
43   // TODO: This seems like a general utility that could be added somewhere.
44   SmallVector<int32_t, 4> values(operandSegmentAttr.getValues<int32_t>());
45   unsigned currentSize = 0;
46   for (unsigned i = 0, e = values.size(); i != e; ++i) {
47     currentSize += values[i];
48     if (fullOperandIndex < currentSize) {
49       --values[i];
50       break;
51     }
52   }
53   op->setAttr("operand_segment_sizes",
54               DenseIntElementsAttr::get(operandSegmentAttr.getType(), values));
55 }
56 
57 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
58 /// successor if 'operandIndex' is within the range of 'operands', or None if
59 /// `operandIndex` isn't a successor operand index.
60 Optional<BlockArgument> mlir::detail::getBranchSuccessorArgument(
61     Optional<OperandRange> operands, unsigned operandIndex, Block *successor) {
62   // Check that the operands are valid.
63   if (!operands || operands->empty())
64     return llvm::None;
65 
66   // Check to ensure that this operand is within the range.
67   unsigned operandsStart = operands->getBeginOperandIndex();
68   if (operandIndex < operandsStart ||
69       operandIndex >= (operandsStart + operands->size()))
70     return llvm::None;
71 
72   // Index the successor.
73   unsigned argIndex = operandIndex - operandsStart;
74   return successor->getArgument(argIndex);
75 }
76 
77 /// Verify that the given operands match those of the given successor block.
78 LogicalResult
79 mlir::detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
80                                             Optional<OperandRange> operands) {
81   if (!operands)
82     return success();
83 
84   // Check the count.
85   unsigned operandCount = operands->size();
86   Block *destBB = op->getSuccessor(succNo);
87   if (operandCount != destBB->getNumArguments())
88     return op->emitError() << "branch has " << operandCount
89                            << " operands for successor #" << succNo
90                            << ", but target block has "
91                            << destBB->getNumArguments();
92 
93   // Check the types.
94   auto operandIt = operands->begin();
95   for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
96     if ((*operandIt).getType() != destBB->getArgument(i).getType())
97       return op->emitError() << "type mismatch for bb argument #" << i
98                              << " of successor #" << succNo;
99   }
100   return success();
101 }
102