1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ControlFlowInterfaces.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "llvm/ADT/SmallPtrSet.h"
12 
13 using namespace mlir;
14 
15 //===----------------------------------------------------------------------===//
16 // ControlFlowInterfaces
17 //===----------------------------------------------------------------------===//
18 
19 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
20 
21 //===----------------------------------------------------------------------===//
22 // BranchOpInterface
23 //===----------------------------------------------------------------------===//
24 
25 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
26 /// successor if 'operandIndex' is within the range of 'operands', or None if
27 /// `operandIndex` isn't a successor operand index.
28 Optional<BlockArgument>
29 detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
30                                    unsigned operandIndex, Block *successor) {
31   // Check that the operands are valid.
32   if (!operands || operands->empty())
33     return llvm::None;
34 
35   // Check to ensure that this operand is within the range.
36   unsigned operandsStart = operands->getBeginOperandIndex();
37   if (operandIndex < operandsStart ||
38       operandIndex >= (operandsStart + operands->size()))
39     return llvm::None;
40 
41   // Index the successor.
42   unsigned argIndex = operandIndex - operandsStart;
43   return successor->getArgument(argIndex);
44 }
45 
46 /// Verify that the given operands match those of the given successor block.
47 LogicalResult
48 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
49                                       Optional<OperandRange> operands) {
50   if (!operands)
51     return success();
52 
53   // Check the count.
54   unsigned operandCount = operands->size();
55   Block *destBB = op->getSuccessor(succNo);
56   if (operandCount != destBB->getNumArguments())
57     return op->emitError() << "branch has " << operandCount
58                            << " operands for successor #" << succNo
59                            << ", but target block has "
60                            << destBB->getNumArguments();
61 
62   // Check the types.
63   auto operandIt = operands->begin();
64   for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
65     if (!cast<BranchOpInterface>(op).areTypesCompatible(
66             (*operandIt).getType(), destBB->getArgument(i).getType()))
67       return op->emitError() << "type mismatch for bb argument #" << i
68                              << " of successor #" << succNo;
69   }
70   return success();
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // RegionBranchOpInterface
75 //===----------------------------------------------------------------------===//
76 
77 /// Verify that types match along all region control flow edges originating from
78 /// `sourceNo` (region # if source is a region, llvm::None if source is parent
79 /// op). `getInputsTypesForRegion` is a function that returns the types of the
80 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if
81 /// the exact type match verification is not necessary (e.g., if the Op verifies
82 /// the match itself).
83 static LogicalResult
84 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
85                          function_ref<Optional<TypeRange>(Optional<unsigned>)>
86                              getInputsTypesForRegion) {
87   auto regionInterface = cast<RegionBranchOpInterface>(op);
88 
89   SmallVector<RegionSuccessor, 2> successors;
90   unsigned numInputs;
91   if (sourceNo) {
92     Region &srcRegion = op->getRegion(sourceNo.getValue());
93     numInputs = srcRegion.getNumArguments();
94   } else {
95     numInputs = op->getNumOperands();
96   }
97   SmallVector<Attribute, 2> operands(numInputs, nullptr);
98   regionInterface.getSuccessorRegions(sourceNo, operands, successors);
99 
100   for (RegionSuccessor &succ : successors) {
101     Optional<unsigned> succRegionNo;
102     if (!succ.isParent())
103       succRegionNo = succ.getSuccessor()->getRegionNumber();
104 
105     auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
106       diag << "from ";
107       if (sourceNo)
108         diag << "Region #" << sourceNo.getValue();
109       else
110         diag << "parent operands";
111 
112       diag << " to ";
113       if (succRegionNo)
114         diag << "Region #" << succRegionNo.getValue();
115       else
116         diag << "parent results";
117       return diag;
118     };
119 
120     Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
121     if (!sourceTypes.hasValue())
122       continue;
123 
124     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
125     if (sourceTypes->size() != succInputsTypes.size()) {
126       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
127       return printEdgeName(diag) << ": source has " << sourceTypes->size()
128                                  << " operands, but target successor needs "
129                                  << succInputsTypes.size();
130     }
131 
132     for (const auto &typesIdx :
133          llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
134       Type sourceType = std::get<0>(typesIdx.value());
135       Type inputType = std::get<1>(typesIdx.value());
136       if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
137         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
138         return printEdgeName(diag)
139                << ": source type #" << typesIdx.index() << " " << sourceType
140                << " should match input type #" << typesIdx.index() << " "
141                << inputType;
142       }
143     }
144   }
145   return success();
146 }
147 
148 /// Verify that types match along control flow edges described the given op.
149 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
150   auto regionInterface = cast<RegionBranchOpInterface>(op);
151 
152   auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
153     if (regionNo.hasValue()) {
154       return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
155           .getTypes();
156     }
157 
158     // If the successor of a parent op is the parent itself
159     // RegionBranchOpInterface does not have an API to query what the entry
160     // operands will be in that case. Vend out the result types of the op in
161     // that case so that type checking succeeds for this case.
162     return op->getResultTypes();
163   };
164 
165   // Verify types along control flow edges originating from the parent.
166   if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
167     return failure();
168 
169   // RegionBranchOpInterface should not be implemented by Ops that do not have
170   // attached regions.
171   assert(op->getNumRegions() != 0);
172 
173   auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
174     if (lhs.size() != rhs.size())
175       return false;
176     for (auto types : llvm::zip(lhs, rhs)) {
177       if (!regionInterface.areTypesCompatible(std::get<0>(types),
178                                               std::get<1>(types))) {
179         return false;
180       }
181     }
182     return true;
183   };
184 
185   // Verify types along control flow edges originating from each region.
186   for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
187     Region &region = op->getRegion(regionNo);
188 
189     // Since there can be multiple `ReturnLike` terminators or others
190     // implementing the `RegionBranchTerminatorOpInterface`, all should have the
191     // same operand types when passing them to the same region.
192 
193     Optional<OperandRange> regionReturnOperands;
194     for (Block &block : region) {
195       Operation *terminator = block.getTerminator();
196       auto terminatorOperands =
197           getRegionBranchSuccessorOperands(terminator, regionNo);
198       if (!terminatorOperands)
199         continue;
200 
201       if (!regionReturnOperands) {
202         regionReturnOperands = terminatorOperands;
203         continue;
204       }
205 
206       // Found more than one ReturnLike terminator. Make sure the operand types
207       // match with the first one.
208       if (!areTypesCompatible(regionReturnOperands->getTypes(),
209                               terminatorOperands->getTypes()))
210         return op->emitOpError("Region #")
211                << regionNo
212                << " operands mismatch between return-like terminators";
213     }
214 
215     auto inputTypesFromRegion =
216         [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
217       // If there is no return-like terminator, the op itself should verify
218       // type consistency.
219       if (!regionReturnOperands)
220         return llvm::None;
221 
222       // All successors get the same set of operand types.
223       return TypeRange(regionReturnOperands->getTypes());
224     };
225 
226     if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
227       return failure();
228   }
229 
230   return success();
231 }
232 
233 /// Return `true` if `a` and `b` are in mutually exclusive regions.
234 ///
235 /// 1. Find the first common of `a` and `b` (ancestor) that implements
236 ///    RegionBranchOpInterface.
237 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
238 ///    contained.
239 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
240 ///    mutually exclusive if they are not reachable from each other as per
241 ///    RegionBranchOpInterface::getSuccessorRegions.
242 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
243   assert(a && "expected non-empty operation");
244   assert(b && "expected non-empty operation");
245 
246   auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
247   while (branchOp) {
248     // Check if b is inside branchOp. (We already know that a is.)
249     if (!branchOp->isProperAncestor(b)) {
250       // Check next enclosing RegionBranchOpInterface.
251       branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
252       continue;
253     }
254 
255     // b is contained in branchOp. Retrieve the regions in which `a` and `b`
256     // are contained.
257     Region *regionA = nullptr, *regionB = nullptr;
258     for (Region &r : branchOp->getRegions()) {
259       if (r.findAncestorOpInRegion(*a)) {
260         assert(!regionA && "already found a region for a");
261         regionA = &r;
262       }
263       if (r.findAncestorOpInRegion(*b)) {
264         assert(!regionB && "already found a region for b");
265         regionB = &r;
266       }
267     }
268     assert(regionA && regionB && "could not find region of op");
269 
270     // Helper function that checks if region `r` is reachable from region
271     // `begin`.
272     std::function<bool(Region *, Region *)> isRegionReachable =
273         [&](Region *begin, Region *r) {
274           if (begin == r)
275             return true;
276           if (begin == nullptr)
277             return false;
278           // Compute index of region.
279           int64_t beginIndex = -1;
280           for (const auto &it : llvm::enumerate(branchOp->getRegions()))
281             if (&it.value() == begin)
282               beginIndex = it.index();
283           assert(beginIndex != -1 && "could not find region in op");
284           // Retrieve all successors of the region.
285           SmallVector<RegionSuccessor> successors;
286           branchOp.getSuccessorRegions(beginIndex, successors);
287           // Call function recursively on all successors.
288           for (RegionSuccessor successor : successors)
289             if (isRegionReachable(successor.getSuccessor(), r))
290               return true;
291           return false;
292         };
293 
294     // `a` and `b` are in mutually exclusive regions if neither region is
295     // reachable from the other region.
296     return !isRegionReachable(regionA, regionB) &&
297            !isRegionReachable(regionB, regionA);
298   }
299 
300   // Could not find a common RegionBranchOpInterface among a's and b's
301   // ancestors.
302   return false;
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // RegionBranchTerminatorOpInterface
307 //===----------------------------------------------------------------------===//
308 
309 /// Returns true if the given operation is either annotated with the
310 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
311 bool mlir::isRegionReturnLike(Operation *operation) {
312   return dyn_cast<RegionBranchTerminatorOpInterface>(operation) ||
313          operation->hasTrait<OpTrait::ReturnLike>();
314 }
315 
316 /// Returns the mutable operands that are passed to the region with the given
317 /// `regionIndex`. If the operation does not implement the
318 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
319 /// result will be `llvm::None`. In all other cases, the resulting
320 /// `OperandRange` represents all operands that are passed to the specified
321 /// successor region. If `regionIndex` is `llvm::None`, all operands that are
322 /// passed to the parent operation will be returned.
323 Optional<MutableOperandRange>
324 mlir::getMutableRegionBranchSuccessorOperands(Operation *operation,
325                                               Optional<unsigned> regionIndex) {
326   // Try to query a RegionBranchTerminatorOpInterface to determine
327   // all successor operands that will be passed to the successor
328   // input arguments.
329   if (auto regionTerminatorInterface =
330           dyn_cast<RegionBranchTerminatorOpInterface>(operation))
331     return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex);
332 
333   // TODO: The ReturnLike trait should imply a default implementation of the
334   // RegionBranchTerminatorOpInterface. This would make this code significantly
335   // easier. Furthermore, this may even make this function obsolete.
336   if (operation->hasTrait<OpTrait::ReturnLike>())
337     return MutableOperandRange(operation);
338   return llvm::None;
339 }
340 
341 /// Returns the read only operands that are passed to the region with the given
342 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
343 /// information.
344 Optional<OperandRange>
345 mlir::getRegionBranchSuccessorOperands(Operation *operation,
346                                        Optional<unsigned> regionIndex) {
347   auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex);
348   return range ? Optional<OperandRange>(*range) : llvm::None;
349 }
350