1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ControlFlowInterfaces.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "llvm/ADT/SmallPtrSet.h"
12 
13 using namespace mlir;
14 
15 //===----------------------------------------------------------------------===//
16 // ControlFlowInterfaces
17 //===----------------------------------------------------------------------===//
18 
19 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
20 
21 //===----------------------------------------------------------------------===//
22 // BranchOpInterface
23 //===----------------------------------------------------------------------===//
24 
25 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
26 /// successor if 'operandIndex' is within the range of 'operands', or None if
27 /// `operandIndex` isn't a successor operand index.
28 Optional<BlockArgument>
29 detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
30                                    unsigned operandIndex, Block *successor) {
31   // Check that the operands are valid.
32   if (!operands || operands->empty())
33     return llvm::None;
34 
35   // Check to ensure that this operand is within the range.
36   unsigned operandsStart = operands->getBeginOperandIndex();
37   if (operandIndex < operandsStart ||
38       operandIndex >= (operandsStart + operands->size()))
39     return llvm::None;
40 
41   // Index the successor.
42   unsigned argIndex = operandIndex - operandsStart;
43   return successor->getArgument(argIndex);
44 }
45 
46 /// Verify that the given operands match those of the given successor block.
47 LogicalResult
48 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
49                                       Optional<OperandRange> operands) {
50   if (!operands)
51     return success();
52 
53   // Check the count.
54   unsigned operandCount = operands->size();
55   Block *destBB = op->getSuccessor(succNo);
56   if (operandCount != destBB->getNumArguments())
57     return op->emitError() << "branch has " << operandCount
58                            << " operands for successor #" << succNo
59                            << ", but target block has "
60                            << destBB->getNumArguments();
61 
62   // Check the types.
63   auto operandIt = operands->begin();
64   for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
65     if ((*operandIt).getType() != destBB->getArgument(i).getType())
66       return op->emitError() << "type mismatch for bb argument #" << i
67                              << " of successor #" << succNo;
68   }
69   return success();
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // RegionBranchOpInterface
74 //===----------------------------------------------------------------------===//
75 
76 // A constant value to represent unknown number of region invocations.
77 const int64_t mlir::kUnknownNumRegionInvocations = -1;
78 
79 /// Verify that types match along all region control flow edges originating from
80 /// `sourceNo` (region # if source is a region, llvm::None if source is parent
81 /// op). `getInputsTypesForRegion` is a function that returns the types of the
82 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if
83 /// the exact type match verification is not necessary (e.g., if the Op verifies
84 /// the match itself).
85 static LogicalResult
86 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
87                          function_ref<Optional<TypeRange>(Optional<unsigned>)>
88                              getInputsTypesForRegion) {
89   auto regionInterface = cast<RegionBranchOpInterface>(op);
90 
91   SmallVector<RegionSuccessor, 2> successors;
92   unsigned numInputs;
93   if (sourceNo) {
94     Region &srcRegion = op->getRegion(sourceNo.getValue());
95     numInputs = srcRegion.getNumArguments();
96   } else {
97     numInputs = op->getNumOperands();
98   }
99   SmallVector<Attribute, 2> operands(numInputs, nullptr);
100   regionInterface.getSuccessorRegions(sourceNo, operands, successors);
101 
102   for (RegionSuccessor &succ : successors) {
103     Optional<unsigned> succRegionNo;
104     if (!succ.isParent())
105       succRegionNo = succ.getSuccessor()->getRegionNumber();
106 
107     auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
108       diag << "from ";
109       if (sourceNo)
110         diag << "Region #" << sourceNo.getValue();
111       else
112         diag << "parent operands";
113 
114       diag << " to ";
115       if (succRegionNo)
116         diag << "Region #" << succRegionNo.getValue();
117       else
118         diag << "parent results";
119       return diag;
120     };
121 
122     Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
123     if (!sourceTypes.hasValue())
124       continue;
125 
126     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
127     if (sourceTypes->size() != succInputsTypes.size()) {
128       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
129       return printEdgeName(diag) << ": source has " << sourceTypes->size()
130                                  << " operands, but target successor needs "
131                                  << succInputsTypes.size();
132     }
133 
134     for (auto typesIdx :
135          llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
136       Type sourceType = std::get<0>(typesIdx.value());
137       Type inputType = std::get<1>(typesIdx.value());
138       if (sourceType != inputType) {
139         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
140         return printEdgeName(diag)
141                << ": source type #" << typesIdx.index() << " " << sourceType
142                << " should match input type #" << typesIdx.index() << " "
143                << inputType;
144       }
145     }
146   }
147   return success();
148 }
149 
150 /// Verify that types match along control flow edges described the given op.
151 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
152   auto regionInterface = cast<RegionBranchOpInterface>(op);
153 
154   auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
155     if (regionNo.hasValue()) {
156       return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
157           .getTypes();
158     }
159 
160     // If the successor of a parent op is the parent itself
161     // RegionBranchOpInterface does not have an API to query what the entry
162     // operands will be in that case. Vend out the result types of the op in
163     // that case so that type checking succeeds for this case.
164     return op->getResultTypes();
165   };
166 
167   // Verify types along control flow edges originating from the parent.
168   if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
169     return failure();
170 
171   // RegionBranchOpInterface should not be implemented by Ops that do not have
172   // attached regions.
173   assert(op->getNumRegions() != 0);
174 
175   // Verify types along control flow edges originating from each region.
176   for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
177     Region &region = op->getRegion(regionNo);
178 
179     // Since there can be multiple `ReturnLike` terminators or others
180     // implementing the `RegionBranchTerminatorOpInterface`, all should have the
181     // same operand types when passing them to the same region.
182 
183     Optional<OperandRange> regionReturnOperands;
184     for (Block &block : region) {
185       Operation *terminator = block.getTerminator();
186       auto terminatorOperands =
187           getRegionBranchSuccessorOperands(terminator, regionNo);
188       if (!terminatorOperands)
189         continue;
190 
191       if (!regionReturnOperands) {
192         regionReturnOperands = terminatorOperands;
193         continue;
194       }
195 
196       // Found more than one ReturnLike terminator. Make sure the operand types
197       // match with the first one.
198       if (regionReturnOperands->getTypes() != terminatorOperands->getTypes())
199         return op->emitOpError("Region #")
200                << regionNo
201                << " operands mismatch between return-like terminators";
202     }
203 
204     auto inputTypesFromRegion =
205         [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
206       // If there is no return-like terminator, the op itself should verify
207       // type consistency.
208       if (!regionReturnOperands)
209         return llvm::None;
210 
211       // All successors get the same set of operand types.
212       return TypeRange(regionReturnOperands->getTypes());
213     };
214 
215     if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
216       return failure();
217   }
218 
219   return success();
220 }
221 
222 //===----------------------------------------------------------------------===//
223 // RegionBranchTerminatorOpInterface
224 //===----------------------------------------------------------------------===//
225 
226 /// Returns true if the given operation is either annotated with the
227 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
228 bool mlir::isRegionReturnLike(Operation *operation) {
229   return dyn_cast<RegionBranchTerminatorOpInterface>(operation) ||
230          operation->hasTrait<OpTrait::ReturnLike>();
231 }
232 
233 /// Returns the mutable operands that are passed to the region with the given
234 /// `regionIndex`. If the operation does not implement the
235 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
236 /// result will be `llvm::None`. In all other cases, the resulting
237 /// `OperandRange` represents all operands that are passed to the specified
238 /// successor region. If `regionIndex` is `llvm::None`, all operands that are
239 /// passed to the parent operation will be returned.
240 Optional<MutableOperandRange>
241 mlir::getMutableRegionBranchSuccessorOperands(Operation *operation,
242                                               Optional<unsigned> regionIndex) {
243   // Try to query a RegionBranchTerminatorOpInterface to determine
244   // all successor operands that will be passed to the successor
245   // input arguments.
246   if (auto regionTerminatorInterface =
247           dyn_cast<RegionBranchTerminatorOpInterface>(operation))
248     return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex);
249 
250   // TODO: The ReturnLike trait should imply a default implementation of the
251   // RegionBranchTerminatorOpInterface. This would make this code significantly
252   // easier. Furthermore, this may even make this function obsolete.
253   if (operation->hasTrait<OpTrait::ReturnLike>())
254     return MutableOperandRange(operation);
255   return llvm::None;
256 }
257 
258 /// Returns the read only operands that are passed to the region with the given
259 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
260 /// information.
261 Optional<OperandRange>
262 mlir::getRegionBranchSuccessorOperands(Operation *operation,
263                                        Optional<unsigned> regionIndex) {
264   auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex);
265   return range ? Optional<OperandRange>(*range) : llvm::None;
266 }
267