//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
87ce1e7abSRiver Riddle 
9f9735be7SMehdi Amini #include <utility>
10f9735be7SMehdi Amini 
11537f2208SMogball #include "mlir/IR/BuiltinTypes.h"
127ce1e7abSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.h"
13a3ad8f92SRahul Joshi #include "llvm/ADT/SmallPtrSet.h"
147ce1e7abSRiver Riddle 
157ce1e7abSRiver Riddle using namespace mlir;
167ce1e7abSRiver Riddle 
//===----------------------------------------------------------------------===//
// ControlFlowInterfaces
//===----------------------------------------------------------------------===//
207ce1e7abSRiver Riddle 
217ce1e7abSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
227ce1e7abSRiver Riddle 
SuccessorOperands(MutableOperandRange forwardedOperands)230c789db5SMarkus Böck SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
24f9735be7SMehdi Amini     : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
25f9735be7SMehdi Amini }
260c789db5SMarkus Böck 
SuccessorOperands(unsigned int producedOperandCount,MutableOperandRange forwardedOperands)270c789db5SMarkus Böck SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
280c789db5SMarkus Böck                                      MutableOperandRange forwardedOperands)
290c789db5SMarkus Böck     : producedOperandCount(producedOperandCount),
300c789db5SMarkus Böck       forwardedOperands(std::move(forwardedOperands)) {}
310c789db5SMarkus Böck 
//===----------------------------------------------------------------------===//
// BranchOpInterface
//===----------------------------------------------------------------------===//
357ce1e7abSRiver Riddle 
367ce1e7abSRiver Riddle /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
377ce1e7abSRiver Riddle /// successor if 'operandIndex' is within the range of 'operands', or None if
387ce1e7abSRiver Riddle /// `operandIndex` isn't a successor operand index.
39a3ad8f92SRahul Joshi Optional<BlockArgument>
getBranchSuccessorArgument(const SuccessorOperands & operands,unsigned operandIndex,Block * successor)400c789db5SMarkus Böck detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
41a3ad8f92SRahul Joshi                                    unsigned operandIndex, Block *successor) {
420c789db5SMarkus Böck   OperandRange forwardedOperands = operands.getForwardedOperands();
437ce1e7abSRiver Riddle   // Check that the operands are valid.
440c789db5SMarkus Böck   if (forwardedOperands.empty())
457ce1e7abSRiver Riddle     return llvm::None;
467ce1e7abSRiver Riddle 
477ce1e7abSRiver Riddle   // Check to ensure that this operand is within the range.
480c789db5SMarkus Böck   unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
497ce1e7abSRiver Riddle   if (operandIndex < operandsStart ||
500c789db5SMarkus Böck       operandIndex >= (operandsStart + forwardedOperands.size()))
517ce1e7abSRiver Riddle     return llvm::None;
527ce1e7abSRiver Riddle 
537ce1e7abSRiver Riddle   // Index the successor.
540c789db5SMarkus Böck   unsigned argIndex =
550c789db5SMarkus Böck       operands.getProducedOperandCount() + operandIndex - operandsStart;
567ce1e7abSRiver Riddle   return successor->getArgument(argIndex);
577ce1e7abSRiver Riddle }
587ce1e7abSRiver Riddle 
597ce1e7abSRiver Riddle /// Verify that the given operands match those of the given successor block.
607ce1e7abSRiver Riddle LogicalResult
verifyBranchSuccessorOperands(Operation * op,unsigned succNo,const SuccessorOperands & operands)61a3ad8f92SRahul Joshi detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
620c789db5SMarkus Böck                                       const SuccessorOperands &operands) {
637ce1e7abSRiver Riddle   // Check the count.
640c789db5SMarkus Böck   unsigned operandCount = operands.size();
657ce1e7abSRiver Riddle   Block *destBB = op->getSuccessor(succNo);
667ce1e7abSRiver Riddle   if (operandCount != destBB->getNumArguments())
677ce1e7abSRiver Riddle     return op->emitError() << "branch has " << operandCount
687ce1e7abSRiver Riddle                            << " operands for successor #" << succNo
697ce1e7abSRiver Riddle                            << ", but target block has "
707ce1e7abSRiver Riddle                            << destBB->getNumArguments();
717ce1e7abSRiver Riddle 
727ce1e7abSRiver Riddle   // Check the types.
730c789db5SMarkus Böck   for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
740c789db5SMarkus Böck        ++i) {
75e7c7b16aSMogball     if (!cast<BranchOpInterface>(op).areTypesCompatible(
760c789db5SMarkus Böck             operands[i].getType(), destBB->getArgument(i).getType()))
777ce1e7abSRiver Riddle       return op->emitError() << "type mismatch for bb argument #" << i
787ce1e7abSRiver Riddle                              << " of successor #" << succNo;
797ce1e7abSRiver Riddle   }
807ce1e7abSRiver Riddle   return success();
817ce1e7abSRiver Riddle }
82a3ad8f92SRahul Joshi 
//===----------------------------------------------------------------------===//
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
86a3ad8f92SRahul Joshi 
87a3ad8f92SRahul Joshi /// Verify that types match along all region control flow edges originating from
88a3ad8f92SRahul Joshi /// `sourceNo` (region # if source is a region, llvm::None if source is parent
89a3ad8f92SRahul Joshi /// op). `getInputsTypesForRegion` is a function that returns the types of the
9079716559SAlex Zinenko /// inputs that flow from `sourceIndex' to the given region, or llvm::None if
9179716559SAlex Zinenko /// the exact type match verification is not necessary (e.g., if the Op verifies
9279716559SAlex Zinenko /// the match itself).
9379716559SAlex Zinenko static LogicalResult
verifyTypesAlongAllEdges(Operation * op,Optional<unsigned> sourceNo,function_ref<Optional<TypeRange> (Optional<unsigned>)> getInputsTypesForRegion)9479716559SAlex Zinenko verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
9579716559SAlex Zinenko                          function_ref<Optional<TypeRange>(Optional<unsigned>)>
9679716559SAlex Zinenko                              getInputsTypesForRegion) {
97a3ad8f92SRahul Joshi   auto regionInterface = cast<RegionBranchOpInterface>(op);
98a3ad8f92SRahul Joshi 
99a3ad8f92SRahul Joshi   SmallVector<RegionSuccessor, 2> successors;
100ee70039aSMogball   regionInterface.getSuccessorRegions(sourceNo, successors);
101a3ad8f92SRahul Joshi 
102a3ad8f92SRahul Joshi   for (RegionSuccessor &succ : successors) {
103a3ad8f92SRahul Joshi     Optional<unsigned> succRegionNo;
104a3ad8f92SRahul Joshi     if (!succ.isParent())
105a3ad8f92SRahul Joshi       succRegionNo = succ.getSuccessor()->getRegionNumber();
106a3ad8f92SRahul Joshi 
107a3ad8f92SRahul Joshi     auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
108a3ad8f92SRahul Joshi       diag << "from ";
109a3ad8f92SRahul Joshi       if (sourceNo)
110*c27d8152SKazu Hirata         diag << "Region #" << sourceNo.value();
111a3ad8f92SRahul Joshi       else
112be35264aSSean Silva         diag << "parent operands";
113a3ad8f92SRahul Joshi 
114a3ad8f92SRahul Joshi       diag << " to ";
115a3ad8f92SRahul Joshi       if (succRegionNo)
116*c27d8152SKazu Hirata         diag << "Region #" << succRegionNo.value();
117a3ad8f92SRahul Joshi       else
118be35264aSSean Silva         diag << "parent results";
119a3ad8f92SRahul Joshi       return diag;
120a3ad8f92SRahul Joshi     };
121a3ad8f92SRahul Joshi 
12279716559SAlex Zinenko     Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
123491d2701SKazu Hirata     if (!sourceTypes.has_value())
12479716559SAlex Zinenko       continue;
12579716559SAlex Zinenko 
126a3ad8f92SRahul Joshi     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
12779716559SAlex Zinenko     if (sourceTypes->size() != succInputsTypes.size()) {
128a3ad8f92SRahul Joshi       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
12979716559SAlex Zinenko       return printEdgeName(diag) << ": source has " << sourceTypes->size()
130be35264aSSean Silva                                  << " operands, but target successor needs "
131a3ad8f92SRahul Joshi                                  << succInputsTypes.size();
132a3ad8f92SRahul Joshi     }
133a3ad8f92SRahul Joshi 
134e4853be2SMehdi Amini     for (const auto &typesIdx :
13579716559SAlex Zinenko          llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
136a3ad8f92SRahul Joshi       Type sourceType = std::get<0>(typesIdx.value());
137a3ad8f92SRahul Joshi       Type inputType = std::get<1>(typesIdx.value());
138e7c7b16aSMogball       if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
139a3ad8f92SRahul Joshi         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
140a3ad8f92SRahul Joshi         return printEdgeName(diag)
141be35264aSSean Silva                << ": source type #" << typesIdx.index() << " " << sourceType
142be35264aSSean Silva                << " should match input type #" << typesIdx.index() << " "
143a3ad8f92SRahul Joshi                << inputType;
144a3ad8f92SRahul Joshi       }
145a3ad8f92SRahul Joshi     }
146a3ad8f92SRahul Joshi   }
147a3ad8f92SRahul Joshi   return success();
148a3ad8f92SRahul Joshi }
149a3ad8f92SRahul Joshi 
150a3ad8f92SRahul Joshi /// Verify that types match along control flow edges described the given op.
verifyTypesAlongControlFlowEdges(Operation * op)151a3ad8f92SRahul Joshi LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
152a3ad8f92SRahul Joshi   auto regionInterface = cast<RegionBranchOpInterface>(op);
153a3ad8f92SRahul Joshi 
154a3ad8f92SRahul Joshi   auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
155537f2208SMogball     return regionInterface.getSuccessorEntryOperands(regionNo).getTypes();
156a3ad8f92SRahul Joshi   };
157a3ad8f92SRahul Joshi 
158a3ad8f92SRahul Joshi   // Verify types along control flow edges originating from the parent.
159a3ad8f92SRahul Joshi   if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
160a3ad8f92SRahul Joshi     return failure();
161a3ad8f92SRahul Joshi 
162a3ad8f92SRahul Joshi   // RegionBranchOpInterface should not be implemented by Ops that do not have
163a3ad8f92SRahul Joshi   // attached regions.
164a3ad8f92SRahul Joshi   assert(op->getNumRegions() != 0);
165a3ad8f92SRahul Joshi 
166e7c7b16aSMogball   auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
167e7c7b16aSMogball     if (lhs.size() != rhs.size())
168e7c7b16aSMogball       return false;
169e7c7b16aSMogball     for (auto types : llvm::zip(lhs, rhs)) {
170e7c7b16aSMogball       if (!regionInterface.areTypesCompatible(std::get<0>(types),
171e7c7b16aSMogball                                               std::get<1>(types))) {
172e7c7b16aSMogball         return false;
173e7c7b16aSMogball       }
174e7c7b16aSMogball     }
175e7c7b16aSMogball     return true;
176e7c7b16aSMogball   };
177e7c7b16aSMogball 
178a3ad8f92SRahul Joshi   // Verify types along control flow edges originating from each region.
179a3ad8f92SRahul Joshi   for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
180a3ad8f92SRahul Joshi     Region &region = op->getRegion(regionNo);
181a3ad8f92SRahul Joshi 
18204253320SMarcel Koester     // Since there can be multiple `ReturnLike` terminators or others
18304253320SMarcel Koester     // implementing the `RegionBranchTerminatorOpInterface`, all should have the
18404253320SMarcel Koester     // same operand types when passing them to the same region.
185a3ad8f92SRahul Joshi 
18604253320SMarcel Koester     Optional<OperandRange> regionReturnOperands;
187a3ad8f92SRahul Joshi     for (Block &block : region) {
188a3ad8f92SRahul Joshi       Operation *terminator = block.getTerminator();
18904253320SMarcel Koester       auto terminatorOperands =
19004253320SMarcel Koester           getRegionBranchSuccessorOperands(terminator, regionNo);
19104253320SMarcel Koester       if (!terminatorOperands)
192a3ad8f92SRahul Joshi         continue;
193a3ad8f92SRahul Joshi 
19404253320SMarcel Koester       if (!regionReturnOperands) {
19504253320SMarcel Koester         regionReturnOperands = terminatorOperands;
196a3ad8f92SRahul Joshi         continue;
197a3ad8f92SRahul Joshi       }
198a3ad8f92SRahul Joshi 
199a3ad8f92SRahul Joshi       // Found more than one ReturnLike terminator. Make sure the operand types
200a3ad8f92SRahul Joshi       // match with the first one.
201e7c7b16aSMogball       if (!areTypesCompatible(regionReturnOperands->getTypes(),
202e7c7b16aSMogball                               terminatorOperands->getTypes()))
203a3ad8f92SRahul Joshi         return op->emitOpError("Region #")
204a3ad8f92SRahul Joshi                << regionNo
205a3ad8f92SRahul Joshi                << " operands mismatch between return-like terminators";
206a3ad8f92SRahul Joshi     }
207a3ad8f92SRahul Joshi 
20879716559SAlex Zinenko     auto inputTypesFromRegion =
20979716559SAlex Zinenko         [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
21079716559SAlex Zinenko       // If there is no return-like terminator, the op itself should verify
21179716559SAlex Zinenko       // type consistency.
21204253320SMarcel Koester       if (!regionReturnOperands)
21379716559SAlex Zinenko         return llvm::None;
21479716559SAlex Zinenko 
21504253320SMarcel Koester       // All successors get the same set of operand types.
21604253320SMarcel Koester       return TypeRange(regionReturnOperands->getTypes());
217a3ad8f92SRahul Joshi     };
218a3ad8f92SRahul Joshi 
219a3ad8f92SRahul Joshi     if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
220a3ad8f92SRahul Joshi       return failure();
221a3ad8f92SRahul Joshi   }
222a3ad8f92SRahul Joshi 
223a3ad8f92SRahul Joshi   return success();
224a3ad8f92SRahul Joshi }
22504253320SMarcel Koester 
226a3005a40SMatthias Springer /// Return `true` if region `r` is reachable from region `begin` according to
227a3005a40SMatthias Springer /// the RegionBranchOpInterface (by taking a branch).
isRegionReachable(Region * begin,Region * r)228a3005a40SMatthias Springer static bool isRegionReachable(Region *begin, Region *r) {
229a3005a40SMatthias Springer   assert(begin->getParentOp() == r->getParentOp() &&
230a3005a40SMatthias Springer          "expected that both regions belong to the same op");
231a3005a40SMatthias Springer   auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
232a3005a40SMatthias Springer   SmallVector<bool> visited(op->getNumRegions(), false);
233a3005a40SMatthias Springer   visited[begin->getRegionNumber()] = true;
234a3005a40SMatthias Springer 
235a3005a40SMatthias Springer   // Retrieve all successors of the region and enqueue them in the worklist.
236a3005a40SMatthias Springer   SmallVector<unsigned> worklist;
237a3005a40SMatthias Springer   auto enqueueAllSuccessors = [&](unsigned index) {
238a3005a40SMatthias Springer     SmallVector<RegionSuccessor> successors;
239a3005a40SMatthias Springer     op.getSuccessorRegions(index, successors);
240a3005a40SMatthias Springer     for (RegionSuccessor successor : successors)
241a3005a40SMatthias Springer       if (!successor.isParent())
242a3005a40SMatthias Springer         worklist.push_back(successor.getSuccessor()->getRegionNumber());
243a3005a40SMatthias Springer   };
244a3005a40SMatthias Springer   enqueueAllSuccessors(begin->getRegionNumber());
245a3005a40SMatthias Springer 
246a3005a40SMatthias Springer   // Process all regions in the worklist via DFS.
247a3005a40SMatthias Springer   while (!worklist.empty()) {
248a3005a40SMatthias Springer     unsigned nextRegion = worklist.pop_back_val();
249a3005a40SMatthias Springer     if (nextRegion == r->getRegionNumber())
250a3005a40SMatthias Springer       return true;
251a3005a40SMatthias Springer     if (visited[nextRegion])
252a3005a40SMatthias Springer       continue;
253a3005a40SMatthias Springer     visited[nextRegion] = true;
254a3005a40SMatthias Springer     enqueueAllSuccessors(nextRegion);
255a3005a40SMatthias Springer   }
256a3005a40SMatthias Springer 
257a3005a40SMatthias Springer   return false;
258a3005a40SMatthias Springer }
259a3005a40SMatthias Springer 
260a5c2f782SMatthias Springer /// Return `true` if `a` and `b` are in mutually exclusive regions.
261a5c2f782SMatthias Springer ///
262a5c2f782SMatthias Springer /// 1. Find the first common of `a` and `b` (ancestor) that implements
263a5c2f782SMatthias Springer ///    RegionBranchOpInterface.
264a5c2f782SMatthias Springer /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
265a5c2f782SMatthias Springer ///    contained.
266a5c2f782SMatthias Springer /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
267a5c2f782SMatthias Springer ///    mutually exclusive if they are not reachable from each other as per
268a5c2f782SMatthias Springer ///    RegionBranchOpInterface::getSuccessorRegions.
insideMutuallyExclusiveRegions(Operation * a,Operation * b)269a5c2f782SMatthias Springer bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
270a5c2f782SMatthias Springer   assert(a && "expected non-empty operation");
271a5c2f782SMatthias Springer   assert(b && "expected non-empty operation");
272a5c2f782SMatthias Springer 
273a5c2f782SMatthias Springer   auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
274a5c2f782SMatthias Springer   while (branchOp) {
275a5c2f782SMatthias Springer     // Check if b is inside branchOp. (We already know that a is.)
276a5c2f782SMatthias Springer     if (!branchOp->isProperAncestor(b)) {
277a5c2f782SMatthias Springer       // Check next enclosing RegionBranchOpInterface.
278a5c2f782SMatthias Springer       branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
279a5c2f782SMatthias Springer       continue;
280a5c2f782SMatthias Springer     }
281a5c2f782SMatthias Springer 
282a5c2f782SMatthias Springer     // b is contained in branchOp. Retrieve the regions in which `a` and `b`
283a5c2f782SMatthias Springer     // are contained.
284a5c2f782SMatthias Springer     Region *regionA = nullptr, *regionB = nullptr;
285a5c2f782SMatthias Springer     for (Region &r : branchOp->getRegions()) {
286a5c2f782SMatthias Springer       if (r.findAncestorOpInRegion(*a)) {
287a5c2f782SMatthias Springer         assert(!regionA && "already found a region for a");
288a5c2f782SMatthias Springer         regionA = &r;
289a5c2f782SMatthias Springer       }
290a5c2f782SMatthias Springer       if (r.findAncestorOpInRegion(*b)) {
291a5c2f782SMatthias Springer         assert(!regionB && "already found a region for b");
292a5c2f782SMatthias Springer         regionB = &r;
293a5c2f782SMatthias Springer       }
294a5c2f782SMatthias Springer     }
295a5c2f782SMatthias Springer     assert(regionA && regionB && "could not find region of op");
296a5c2f782SMatthias Springer 
297a3005a40SMatthias Springer     // `a` and `b` are in mutually exclusive regions if both regions are
298a3005a40SMatthias Springer     // distinct and neither region is reachable from the other region.
299a3005a40SMatthias Springer     return regionA != regionB && !isRegionReachable(regionA, regionB) &&
300a5c2f782SMatthias Springer            !isRegionReachable(regionB, regionA);
301a5c2f782SMatthias Springer   }
302a5c2f782SMatthias Springer 
303a5c2f782SMatthias Springer   // Could not find a common RegionBranchOpInterface among a's and b's
304a5c2f782SMatthias Springer   // ancestors.
305a5c2f782SMatthias Springer   return false;
306a5c2f782SMatthias Springer }
307a5c2f782SMatthias Springer 
isRepetitiveRegion(unsigned index)3080f4ba02dSMatthias Springer bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
309a3005a40SMatthias Springer   Region *region = &getOperation()->getRegion(index);
310a3005a40SMatthias Springer   return isRegionReachable(region, region);
3110f4ba02dSMatthias Springer }
3120f4ba02dSMatthias Springer 
getSuccessorRegions(Optional<unsigned> index,SmallVectorImpl<RegionSuccessor> & regions)313ee70039aSMogball void RegionBranchOpInterface::getSuccessorRegions(
314ee70039aSMogball     Optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
315ee70039aSMogball   unsigned numInputs = 0;
316ee70039aSMogball   if (index) {
317ee70039aSMogball     // If the predecessor is a region, get the number of operands from an
318ee70039aSMogball     // exiting terminator in the region.
319ee70039aSMogball     for (Block &block : getOperation()->getRegion(*index)) {
320ee70039aSMogball       Operation *terminator = block.getTerminator();
321ee70039aSMogball       if (getRegionBranchSuccessorOperands(terminator, *index)) {
322ee70039aSMogball         numInputs = terminator->getNumOperands();
323ee70039aSMogball         break;
324ee70039aSMogball       }
325ee70039aSMogball     }
326ee70039aSMogball   } else {
327ee70039aSMogball     // Otherwise, use the number of parent operation operands.
328ee70039aSMogball     numInputs = getOperation()->getNumOperands();
329ee70039aSMogball   }
330ee70039aSMogball   SmallVector<Attribute, 2> operands(numInputs, nullptr);
331ee70039aSMogball   getSuccessorRegions(index, operands, regions);
332ee70039aSMogball }
333ee70039aSMogball 
getEnclosingRepetitiveRegion(Operation * op)3340f4ba02dSMatthias Springer Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
3350f4ba02dSMatthias Springer   while (Region *region = op->getParentRegion()) {
3360f4ba02dSMatthias Springer     op = region->getParentOp();
3370f4ba02dSMatthias Springer     if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
3380f4ba02dSMatthias Springer       if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
3390f4ba02dSMatthias Springer         return region;
3400f4ba02dSMatthias Springer   }
3410f4ba02dSMatthias Springer   return nullptr;
3420f4ba02dSMatthias Springer }
3430f4ba02dSMatthias Springer 
getEnclosingRepetitiveRegion(Value value)3440f4ba02dSMatthias Springer Region *mlir::getEnclosingRepetitiveRegion(Value value) {
3450f4ba02dSMatthias Springer   Region *region = value.getParentRegion();
3460f4ba02dSMatthias Springer   while (region) {
3470f4ba02dSMatthias Springer     Operation *op = region->getParentOp();
3480f4ba02dSMatthias Springer     if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
3490f4ba02dSMatthias Springer       if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
3500f4ba02dSMatthias Springer         return region;
3510f4ba02dSMatthias Springer     region = op->getParentRegion();
3520f4ba02dSMatthias Springer   }
3530f4ba02dSMatthias Springer   return nullptr;
3540f4ba02dSMatthias Springer }
3550f4ba02dSMatthias Springer 
//===----------------------------------------------------------------------===//
// RegionBranchTerminatorOpInterface
//===----------------------------------------------------------------------===//
35904253320SMarcel Koester 
36004253320SMarcel Koester /// Returns true if the given operation is either annotated with the
36104253320SMarcel Koester /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
isRegionReturnLike(Operation * operation)36204253320SMarcel Koester bool mlir::isRegionReturnLike(Operation *operation) {
36304253320SMarcel Koester   return dyn_cast<RegionBranchTerminatorOpInterface>(operation) ||
36404253320SMarcel Koester          operation->hasTrait<OpTrait::ReturnLike>();
36504253320SMarcel Koester }
36604253320SMarcel Koester 
36704253320SMarcel Koester /// Returns the mutable operands that are passed to the region with the given
36804253320SMarcel Koester /// `regionIndex`. If the operation does not implement the
36904253320SMarcel Koester /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
37004253320SMarcel Koester /// result will be `llvm::None`. In all other cases, the resulting
37104253320SMarcel Koester /// `OperandRange` represents all operands that are passed to the specified
37204253320SMarcel Koester /// successor region. If `regionIndex` is `llvm::None`, all operands that are
37304253320SMarcel Koester /// passed to the parent operation will be returned.
37404253320SMarcel Koester Optional<MutableOperandRange>
getMutableRegionBranchSuccessorOperands(Operation * operation,Optional<unsigned> regionIndex)37504253320SMarcel Koester mlir::getMutableRegionBranchSuccessorOperands(Operation *operation,
37604253320SMarcel Koester                                               Optional<unsigned> regionIndex) {
37704253320SMarcel Koester   // Try to query a RegionBranchTerminatorOpInterface to determine
37804253320SMarcel Koester   // all successor operands that will be passed to the successor
37904253320SMarcel Koester   // input arguments.
38004253320SMarcel Koester   if (auto regionTerminatorInterface =
38104253320SMarcel Koester           dyn_cast<RegionBranchTerminatorOpInterface>(operation))
38204253320SMarcel Koester     return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex);
38304253320SMarcel Koester 
38404253320SMarcel Koester   // TODO: The ReturnLike trait should imply a default implementation of the
38504253320SMarcel Koester   // RegionBranchTerminatorOpInterface. This would make this code significantly
38604253320SMarcel Koester   // easier. Furthermore, this may even make this function obsolete.
38704253320SMarcel Koester   if (operation->hasTrait<OpTrait::ReturnLike>())
38804253320SMarcel Koester     return MutableOperandRange(operation);
38904253320SMarcel Koester   return llvm::None;
39004253320SMarcel Koester }
39104253320SMarcel Koester 
39204253320SMarcel Koester /// Returns the read only operands that are passed to the region with the given
39304253320SMarcel Koester /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
39404253320SMarcel Koester /// information.
39504253320SMarcel Koester Optional<OperandRange>
getRegionBranchSuccessorOperands(Operation * operation,Optional<unsigned> regionIndex)39604253320SMarcel Koester mlir::getRegionBranchSuccessorOperands(Operation *operation,
39704253320SMarcel Koester                                        Optional<unsigned> regionIndex) {
39804253320SMarcel Koester   auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex);
39904253320SMarcel Koester   return range ? Optional<OperandRange>(*range) : llvm::None;
40004253320SMarcel Koester }
401