1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Interfaces/ControlFlowInterfaces.h"
12 #include "llvm/ADT/SmallPtrSet.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // ControlFlowInterfaces
18 //===----------------------------------------------------------------------===//
19 
20 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
21 
22 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
23     : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
24 }
25 
26 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
27                                      MutableOperandRange forwardedOperands)
28     : producedOperandCount(producedOperandCount),
29       forwardedOperands(std::move(forwardedOperands)) {}
30 
31 //===----------------------------------------------------------------------===//
32 // BranchOpInterface
33 //===----------------------------------------------------------------------===//
34 
35 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
36 /// successor if 'operandIndex' is within the range of 'operands', or None if
37 /// `operandIndex` isn't a successor operand index.
38 Optional<BlockArgument>
39 detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
40                                    unsigned operandIndex, Block *successor) {
41   OperandRange forwardedOperands = operands.getForwardedOperands();
42   // Check that the operands are valid.
43   if (forwardedOperands.empty())
44     return llvm::None;
45 
46   // Check to ensure that this operand is within the range.
47   unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
48   if (operandIndex < operandsStart ||
49       operandIndex >= (operandsStart + forwardedOperands.size()))
50     return llvm::None;
51 
52   // Index the successor.
53   unsigned argIndex =
54       operands.getProducedOperandCount() + operandIndex - operandsStart;
55   return successor->getArgument(argIndex);
56 }
57 
58 /// Verify that the given operands match those of the given successor block.
59 LogicalResult
60 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
61                                       const SuccessorOperands &operands) {
62   // Check the count.
63   unsigned operandCount = operands.size();
64   Block *destBB = op->getSuccessor(succNo);
65   if (operandCount != destBB->getNumArguments())
66     return op->emitError() << "branch has " << operandCount
67                            << " operands for successor #" << succNo
68                            << ", but target block has "
69                            << destBB->getNumArguments();
70 
71   // Check the types.
72   for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
73        ++i) {
74     if (!cast<BranchOpInterface>(op).areTypesCompatible(
75             operands[i].getType(), destBB->getArgument(i).getType()))
76       return op->emitError() << "type mismatch for bb argument #" << i
77                              << " of successor #" << succNo;
78   }
79   return success();
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // RegionBranchOpInterface
84 //===----------------------------------------------------------------------===//
85 
86 /// Verify that types match along all region control flow edges originating from
87 /// `sourceNo` (region # if source is a region, llvm::None if source is parent
88 /// op). `getInputsTypesForRegion` is a function that returns the types of the
89 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if
90 /// the exact type match verification is not necessary (e.g., if the Op verifies
91 /// the match itself).
92 static LogicalResult
93 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
94                          function_ref<Optional<TypeRange>(Optional<unsigned>)>
95                              getInputsTypesForRegion) {
96   auto regionInterface = cast<RegionBranchOpInterface>(op);
97 
98   SmallVector<RegionSuccessor, 2> successors;
99   regionInterface.getSuccessorRegions(sourceNo, successors);
100 
101   for (RegionSuccessor &succ : successors) {
102     Optional<unsigned> succRegionNo;
103     if (!succ.isParent())
104       succRegionNo = succ.getSuccessor()->getRegionNumber();
105 
106     auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
107       diag << "from ";
108       if (sourceNo)
109         diag << "Region #" << sourceNo.getValue();
110       else
111         diag << "parent operands";
112 
113       diag << " to ";
114       if (succRegionNo)
115         diag << "Region #" << succRegionNo.getValue();
116       else
117         diag << "parent results";
118       return diag;
119     };
120 
121     Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
122     if (!sourceTypes.hasValue())
123       continue;
124 
125     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
126     if (sourceTypes->size() != succInputsTypes.size()) {
127       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
128       return printEdgeName(diag) << ": source has " << sourceTypes->size()
129                                  << " operands, but target successor needs "
130                                  << succInputsTypes.size();
131     }
132 
133     for (const auto &typesIdx :
134          llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
135       Type sourceType = std::get<0>(typesIdx.value());
136       Type inputType = std::get<1>(typesIdx.value());
137       if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
138         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
139         return printEdgeName(diag)
140                << ": source type #" << typesIdx.index() << " " << sourceType
141                << " should match input type #" << typesIdx.index() << " "
142                << inputType;
143       }
144     }
145   }
146   return success();
147 }
148 
149 /// Verify that types match along control flow edges described the given op.
150 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
151   auto regionInterface = cast<RegionBranchOpInterface>(op);
152 
153   auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
154     if (regionNo.hasValue()) {
155       return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
156           .getTypes();
157     }
158 
159     // If the successor of a parent op is the parent itself
160     // RegionBranchOpInterface does not have an API to query what the entry
161     // operands will be in that case. Vend out the result types of the op in
162     // that case so that type checking succeeds for this case.
163     return op->getResultTypes();
164   };
165 
166   // Verify types along control flow edges originating from the parent.
167   if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
168     return failure();
169 
170   // RegionBranchOpInterface should not be implemented by Ops that do not have
171   // attached regions.
172   assert(op->getNumRegions() != 0);
173 
174   auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
175     if (lhs.size() != rhs.size())
176       return false;
177     for (auto types : llvm::zip(lhs, rhs)) {
178       if (!regionInterface.areTypesCompatible(std::get<0>(types),
179                                               std::get<1>(types))) {
180         return false;
181       }
182     }
183     return true;
184   };
185 
186   // Verify types along control flow edges originating from each region.
187   for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
188     Region &region = op->getRegion(regionNo);
189 
190     // Since there can be multiple `ReturnLike` terminators or others
191     // implementing the `RegionBranchTerminatorOpInterface`, all should have the
192     // same operand types when passing them to the same region.
193 
194     Optional<OperandRange> regionReturnOperands;
195     for (Block &block : region) {
196       Operation *terminator = block.getTerminator();
197       auto terminatorOperands =
198           getRegionBranchSuccessorOperands(terminator, regionNo);
199       if (!terminatorOperands)
200         continue;
201 
202       if (!regionReturnOperands) {
203         regionReturnOperands = terminatorOperands;
204         continue;
205       }
206 
207       // Found more than one ReturnLike terminator. Make sure the operand types
208       // match with the first one.
209       if (!areTypesCompatible(regionReturnOperands->getTypes(),
210                               terminatorOperands->getTypes()))
211         return op->emitOpError("Region #")
212                << regionNo
213                << " operands mismatch between return-like terminators";
214     }
215 
216     auto inputTypesFromRegion =
217         [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
218       // If there is no return-like terminator, the op itself should verify
219       // type consistency.
220       if (!regionReturnOperands)
221         return llvm::None;
222 
223       // All successors get the same set of operand types.
224       return TypeRange(regionReturnOperands->getTypes());
225     };
226 
227     if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
228       return failure();
229   }
230 
231   return success();
232 }
233 
234 /// Return `true` if region `r` is reachable from region `begin` according to
235 /// the RegionBranchOpInterface (by taking a branch).
236 static bool isRegionReachable(Region *begin, Region *r) {
237   assert(begin->getParentOp() == r->getParentOp() &&
238          "expected that both regions belong to the same op");
239   auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
240   SmallVector<bool> visited(op->getNumRegions(), false);
241   visited[begin->getRegionNumber()] = true;
242 
243   // Retrieve all successors of the region and enqueue them in the worklist.
244   SmallVector<unsigned> worklist;
245   auto enqueueAllSuccessors = [&](unsigned index) {
246     SmallVector<RegionSuccessor> successors;
247     op.getSuccessorRegions(index, successors);
248     for (RegionSuccessor successor : successors)
249       if (!successor.isParent())
250         worklist.push_back(successor.getSuccessor()->getRegionNumber());
251   };
252   enqueueAllSuccessors(begin->getRegionNumber());
253 
254   // Process all regions in the worklist via DFS.
255   while (!worklist.empty()) {
256     unsigned nextRegion = worklist.pop_back_val();
257     if (nextRegion == r->getRegionNumber())
258       return true;
259     if (visited[nextRegion])
260       continue;
261     visited[nextRegion] = true;
262     enqueueAllSuccessors(nextRegion);
263   }
264 
265   return false;
266 }
267 
268 /// Return `true` if `a` and `b` are in mutually exclusive regions.
269 ///
270 /// 1. Find the first common of `a` and `b` (ancestor) that implements
271 ///    RegionBranchOpInterface.
272 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
273 ///    contained.
274 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
275 ///    mutually exclusive if they are not reachable from each other as per
276 ///    RegionBranchOpInterface::getSuccessorRegions.
277 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
278   assert(a && "expected non-empty operation");
279   assert(b && "expected non-empty operation");
280 
281   auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
282   while (branchOp) {
283     // Check if b is inside branchOp. (We already know that a is.)
284     if (!branchOp->isProperAncestor(b)) {
285       // Check next enclosing RegionBranchOpInterface.
286       branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
287       continue;
288     }
289 
290     // b is contained in branchOp. Retrieve the regions in which `a` and `b`
291     // are contained.
292     Region *regionA = nullptr, *regionB = nullptr;
293     for (Region &r : branchOp->getRegions()) {
294       if (r.findAncestorOpInRegion(*a)) {
295         assert(!regionA && "already found a region for a");
296         regionA = &r;
297       }
298       if (r.findAncestorOpInRegion(*b)) {
299         assert(!regionB && "already found a region for b");
300         regionB = &r;
301       }
302     }
303     assert(regionA && regionB && "could not find region of op");
304 
305     // `a` and `b` are in mutually exclusive regions if both regions are
306     // distinct and neither region is reachable from the other region.
307     return regionA != regionB && !isRegionReachable(regionA, regionB) &&
308            !isRegionReachable(regionB, regionA);
309   }
310 
311   // Could not find a common RegionBranchOpInterface among a's and b's
312   // ancestors.
313   return false;
314 }
315 
316 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
317   Region *region = &getOperation()->getRegion(index);
318   return isRegionReachable(region, region);
319 }
320 
321 void RegionBranchOpInterface::getSuccessorRegions(
322     Optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
323   unsigned numInputs = 0;
324   if (index) {
325     // If the predecessor is a region, get the number of operands from an
326     // exiting terminator in the region.
327     for (Block &block : getOperation()->getRegion(*index)) {
328       Operation *terminator = block.getTerminator();
329       if (getRegionBranchSuccessorOperands(terminator, *index)) {
330         numInputs = terminator->getNumOperands();
331         break;
332       }
333     }
334   } else {
335     // Otherwise, use the number of parent operation operands.
336     numInputs = getOperation()->getNumOperands();
337   }
338   SmallVector<Attribute, 2> operands(numInputs, nullptr);
339   getSuccessorRegions(index, operands, regions);
340 }
341 
342 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
343   while (Region *region = op->getParentRegion()) {
344     op = region->getParentOp();
345     if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
346       if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
347         return region;
348   }
349   return nullptr;
350 }
351 
352 Region *mlir::getEnclosingRepetitiveRegion(Value value) {
353   Region *region = value.getParentRegion();
354   while (region) {
355     Operation *op = region->getParentOp();
356     if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
357       if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
358         return region;
359     region = op->getParentRegion();
360   }
361   return nullptr;
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // RegionBranchTerminatorOpInterface
366 //===----------------------------------------------------------------------===//
367 
368 /// Returns true if the given operation is either annotated with the
369 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
370 bool mlir::isRegionReturnLike(Operation *operation) {
371   return dyn_cast<RegionBranchTerminatorOpInterface>(operation) ||
372          operation->hasTrait<OpTrait::ReturnLike>();
373 }
374 
375 /// Returns the mutable operands that are passed to the region with the given
376 /// `regionIndex`. If the operation does not implement the
377 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
378 /// result will be `llvm::None`. In all other cases, the resulting
379 /// `OperandRange` represents all operands that are passed to the specified
380 /// successor region. If `regionIndex` is `llvm::None`, all operands that are
381 /// passed to the parent operation will be returned.
382 Optional<MutableOperandRange>
383 mlir::getMutableRegionBranchSuccessorOperands(Operation *operation,
384                                               Optional<unsigned> regionIndex) {
385   // Try to query a RegionBranchTerminatorOpInterface to determine
386   // all successor operands that will be passed to the successor
387   // input arguments.
388   if (auto regionTerminatorInterface =
389           dyn_cast<RegionBranchTerminatorOpInterface>(operation))
390     return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex);
391 
392   // TODO: The ReturnLike trait should imply a default implementation of the
393   // RegionBranchTerminatorOpInterface. This would make this code significantly
394   // easier. Furthermore, this may even make this function obsolete.
395   if (operation->hasTrait<OpTrait::ReturnLike>())
396     return MutableOperandRange(operation);
397   return llvm::None;
398 }
399 
400 /// Returns the read only operands that are passed to the region with the given
401 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
402 /// information.
403 Optional<OperandRange>
404 mlir::getRegionBranchSuccessorOperands(Operation *operation,
405                                        Optional<unsigned> regionIndex) {
406   auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex);
407   return range ? Optional<OperandRange>(*range) : llvm::None;
408 }
409