17ce1e7abSRiver Riddle //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
27ce1e7abSRiver Riddle //
37ce1e7abSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47ce1e7abSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
57ce1e7abSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67ce1e7abSRiver Riddle //
77ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
87ce1e7abSRiver Riddle 
97ce1e7abSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.h"
1009f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
11a3ad8f92SRahul Joshi #include "llvm/ADT/SmallPtrSet.h"
127ce1e7abSRiver Riddle 
137ce1e7abSRiver Riddle using namespace mlir;
147ce1e7abSRiver Riddle 
157ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
167ce1e7abSRiver Riddle // ControlFlowInterfaces
177ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
187ce1e7abSRiver Riddle 
197ce1e7abSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
207ce1e7abSRiver Riddle 
217ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
227ce1e7abSRiver Riddle // BranchOpInterface
237ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
247ce1e7abSRiver Riddle 
257ce1e7abSRiver Riddle /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
267ce1e7abSRiver Riddle /// successor if 'operandIndex' is within the range of 'operands', or None if
277ce1e7abSRiver Riddle /// `operandIndex` isn't a successor operand index.
28a3ad8f92SRahul Joshi Optional<BlockArgument>
29a3ad8f92SRahul Joshi detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
30a3ad8f92SRahul Joshi                                    unsigned operandIndex, Block *successor) {
317ce1e7abSRiver Riddle   // Check that the operands are valid.
327ce1e7abSRiver Riddle   if (!operands || operands->empty())
337ce1e7abSRiver Riddle     return llvm::None;
347ce1e7abSRiver Riddle 
357ce1e7abSRiver Riddle   // Check to ensure that this operand is within the range.
367ce1e7abSRiver Riddle   unsigned operandsStart = operands->getBeginOperandIndex();
377ce1e7abSRiver Riddle   if (operandIndex < operandsStart ||
387ce1e7abSRiver Riddle       operandIndex >= (operandsStart + operands->size()))
397ce1e7abSRiver Riddle     return llvm::None;
407ce1e7abSRiver Riddle 
417ce1e7abSRiver Riddle   // Index the successor.
427ce1e7abSRiver Riddle   unsigned argIndex = operandIndex - operandsStart;
437ce1e7abSRiver Riddle   return successor->getArgument(argIndex);
447ce1e7abSRiver Riddle }
457ce1e7abSRiver Riddle 
467ce1e7abSRiver Riddle /// Verify that the given operands match those of the given successor block.
477ce1e7abSRiver Riddle LogicalResult
48a3ad8f92SRahul Joshi detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
497ce1e7abSRiver Riddle                                       Optional<OperandRange> operands) {
507ce1e7abSRiver Riddle   if (!operands)
517ce1e7abSRiver Riddle     return success();
527ce1e7abSRiver Riddle 
537ce1e7abSRiver Riddle   // Check the count.
547ce1e7abSRiver Riddle   unsigned operandCount = operands->size();
557ce1e7abSRiver Riddle   Block *destBB = op->getSuccessor(succNo);
567ce1e7abSRiver Riddle   if (operandCount != destBB->getNumArguments())
577ce1e7abSRiver Riddle     return op->emitError() << "branch has " << operandCount
587ce1e7abSRiver Riddle                            << " operands for successor #" << succNo
597ce1e7abSRiver Riddle                            << ", but target block has "
607ce1e7abSRiver Riddle                            << destBB->getNumArguments();
617ce1e7abSRiver Riddle 
627ce1e7abSRiver Riddle   // Check the types.
637ce1e7abSRiver Riddle   auto operandIt = operands->begin();
647ce1e7abSRiver Riddle   for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
657ce1e7abSRiver Riddle     if ((*operandIt).getType() != destBB->getArgument(i).getType())
667ce1e7abSRiver Riddle       return op->emitError() << "type mismatch for bb argument #" << i
677ce1e7abSRiver Riddle                              << " of successor #" << succNo;
687ce1e7abSRiver Riddle   }
697ce1e7abSRiver Riddle   return success();
707ce1e7abSRiver Riddle }
71a3ad8f92SRahul Joshi 
72a3ad8f92SRahul Joshi //===----------------------------------------------------------------------===//
73a3ad8f92SRahul Joshi // RegionBranchOpInterface
74a3ad8f92SRahul Joshi //===----------------------------------------------------------------------===//
75a3ad8f92SRahul Joshi 
76bb0d5f76SEugene Zhulenev // A constant value to represent unknown number of region invocations.
77bb0d5f76SEugene Zhulenev const int64_t mlir::kUnknownNumRegionInvocations = -1;
78bb0d5f76SEugene Zhulenev 
79a3ad8f92SRahul Joshi /// Verify that types match along all region control flow edges originating from
80a3ad8f92SRahul Joshi /// `sourceNo` (region # if source is a region, llvm::None if source is parent
81a3ad8f92SRahul Joshi /// op). `getInputsTypesForRegion` is a function that returns the types of the
8279716559SAlex Zinenko /// inputs that flow from `sourceIndex' to the given region, or llvm::None if
8379716559SAlex Zinenko /// the exact type match verification is not necessary (e.g., if the Op verifies
8479716559SAlex Zinenko /// the match itself).
8579716559SAlex Zinenko static LogicalResult
8679716559SAlex Zinenko verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
8779716559SAlex Zinenko                          function_ref<Optional<TypeRange>(Optional<unsigned>)>
8879716559SAlex Zinenko                              getInputsTypesForRegion) {
89a3ad8f92SRahul Joshi   auto regionInterface = cast<RegionBranchOpInterface>(op);
90a3ad8f92SRahul Joshi 
91a3ad8f92SRahul Joshi   SmallVector<RegionSuccessor, 2> successors;
92a3ad8f92SRahul Joshi   unsigned numInputs;
93a3ad8f92SRahul Joshi   if (sourceNo) {
94a3ad8f92SRahul Joshi     Region &srcRegion = op->getRegion(sourceNo.getValue());
95a3ad8f92SRahul Joshi     numInputs = srcRegion.getNumArguments();
96a3ad8f92SRahul Joshi   } else {
97a3ad8f92SRahul Joshi     numInputs = op->getNumOperands();
98a3ad8f92SRahul Joshi   }
99a3ad8f92SRahul Joshi   SmallVector<Attribute, 2> operands(numInputs, nullptr);
100a3ad8f92SRahul Joshi   regionInterface.getSuccessorRegions(sourceNo, operands, successors);
101a3ad8f92SRahul Joshi 
102a3ad8f92SRahul Joshi   for (RegionSuccessor &succ : successors) {
103a3ad8f92SRahul Joshi     Optional<unsigned> succRegionNo;
104a3ad8f92SRahul Joshi     if (!succ.isParent())
105a3ad8f92SRahul Joshi       succRegionNo = succ.getSuccessor()->getRegionNumber();
106a3ad8f92SRahul Joshi 
107a3ad8f92SRahul Joshi     auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
108a3ad8f92SRahul Joshi       diag << "from ";
109a3ad8f92SRahul Joshi       if (sourceNo)
110a3ad8f92SRahul Joshi         diag << "Region #" << sourceNo.getValue();
111a3ad8f92SRahul Joshi       else
112be35264aSSean Silva         diag << "parent operands";
113a3ad8f92SRahul Joshi 
114a3ad8f92SRahul Joshi       diag << " to ";
115a3ad8f92SRahul Joshi       if (succRegionNo)
116a3ad8f92SRahul Joshi         diag << "Region #" << succRegionNo.getValue();
117a3ad8f92SRahul Joshi       else
118be35264aSSean Silva         diag << "parent results";
119a3ad8f92SRahul Joshi       return diag;
120a3ad8f92SRahul Joshi     };
121a3ad8f92SRahul Joshi 
12279716559SAlex Zinenko     Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
12379716559SAlex Zinenko     if (!sourceTypes.hasValue())
12479716559SAlex Zinenko       continue;
12579716559SAlex Zinenko 
126a3ad8f92SRahul Joshi     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
12779716559SAlex Zinenko     if (sourceTypes->size() != succInputsTypes.size()) {
128a3ad8f92SRahul Joshi       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
12979716559SAlex Zinenko       return printEdgeName(diag) << ": source has " << sourceTypes->size()
130be35264aSSean Silva                                  << " operands, but target successor needs "
131a3ad8f92SRahul Joshi                                  << succInputsTypes.size();
132a3ad8f92SRahul Joshi     }
133a3ad8f92SRahul Joshi 
134*e4853be2SMehdi Amini     for (const auto &typesIdx :
13579716559SAlex Zinenko          llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
136a3ad8f92SRahul Joshi       Type sourceType = std::get<0>(typesIdx.value());
137a3ad8f92SRahul Joshi       Type inputType = std::get<1>(typesIdx.value());
138a3ad8f92SRahul Joshi       if (sourceType != inputType) {
139a3ad8f92SRahul Joshi         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
140a3ad8f92SRahul Joshi         return printEdgeName(diag)
141be35264aSSean Silva                << ": source type #" << typesIdx.index() << " " << sourceType
142be35264aSSean Silva                << " should match input type #" << typesIdx.index() << " "
143a3ad8f92SRahul Joshi                << inputType;
144a3ad8f92SRahul Joshi       }
145a3ad8f92SRahul Joshi     }
146a3ad8f92SRahul Joshi   }
147a3ad8f92SRahul Joshi   return success();
148a3ad8f92SRahul Joshi }
149a3ad8f92SRahul Joshi 
150a3ad8f92SRahul Joshi /// Verify that types match along control flow edges described the given op.
151a3ad8f92SRahul Joshi LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
152a3ad8f92SRahul Joshi   auto regionInterface = cast<RegionBranchOpInterface>(op);
153a3ad8f92SRahul Joshi 
154a3ad8f92SRahul Joshi   auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
155a3ad8f92SRahul Joshi     if (regionNo.hasValue()) {
156a3ad8f92SRahul Joshi       return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
157a3ad8f92SRahul Joshi           .getTypes();
158a3ad8f92SRahul Joshi     }
159a3ad8f92SRahul Joshi 
160a3ad8f92SRahul Joshi     // If the successor of a parent op is the parent itself
161a3ad8f92SRahul Joshi     // RegionBranchOpInterface does not have an API to query what the entry
162a3ad8f92SRahul Joshi     // operands will be in that case. Vend out the result types of the op in
163a3ad8f92SRahul Joshi     // that case so that type checking succeeds for this case.
164a3ad8f92SRahul Joshi     return op->getResultTypes();
165a3ad8f92SRahul Joshi   };
166a3ad8f92SRahul Joshi 
167a3ad8f92SRahul Joshi   // Verify types along control flow edges originating from the parent.
168a3ad8f92SRahul Joshi   if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
169a3ad8f92SRahul Joshi     return failure();
170a3ad8f92SRahul Joshi 
171a3ad8f92SRahul Joshi   // RegionBranchOpInterface should not be implemented by Ops that do not have
172a3ad8f92SRahul Joshi   // attached regions.
173a3ad8f92SRahul Joshi   assert(op->getNumRegions() != 0);
174a3ad8f92SRahul Joshi 
175a3ad8f92SRahul Joshi   // Verify types along control flow edges originating from each region.
176a3ad8f92SRahul Joshi   for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
177a3ad8f92SRahul Joshi     Region &region = op->getRegion(regionNo);
178a3ad8f92SRahul Joshi 
17904253320SMarcel Koester     // Since there can be multiple `ReturnLike` terminators or others
18004253320SMarcel Koester     // implementing the `RegionBranchTerminatorOpInterface`, all should have the
18104253320SMarcel Koester     // same operand types when passing them to the same region.
182a3ad8f92SRahul Joshi 
18304253320SMarcel Koester     Optional<OperandRange> regionReturnOperands;
184a3ad8f92SRahul Joshi     for (Block &block : region) {
185a3ad8f92SRahul Joshi       Operation *terminator = block.getTerminator();
18604253320SMarcel Koester       auto terminatorOperands =
18704253320SMarcel Koester           getRegionBranchSuccessorOperands(terminator, regionNo);
18804253320SMarcel Koester       if (!terminatorOperands)
189a3ad8f92SRahul Joshi         continue;
190a3ad8f92SRahul Joshi 
19104253320SMarcel Koester       if (!regionReturnOperands) {
19204253320SMarcel Koester         regionReturnOperands = terminatorOperands;
193a3ad8f92SRahul Joshi         continue;
194a3ad8f92SRahul Joshi       }
195a3ad8f92SRahul Joshi 
196a3ad8f92SRahul Joshi       // Found more than one ReturnLike terminator. Make sure the operand types
197a3ad8f92SRahul Joshi       // match with the first one.
19804253320SMarcel Koester       if (regionReturnOperands->getTypes() != terminatorOperands->getTypes())
199a3ad8f92SRahul Joshi         return op->emitOpError("Region #")
200a3ad8f92SRahul Joshi                << regionNo
201a3ad8f92SRahul Joshi                << " operands mismatch between return-like terminators";
202a3ad8f92SRahul Joshi     }
203a3ad8f92SRahul Joshi 
20479716559SAlex Zinenko     auto inputTypesFromRegion =
20579716559SAlex Zinenko         [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
20679716559SAlex Zinenko       // If there is no return-like terminator, the op itself should verify
20779716559SAlex Zinenko       // type consistency.
20804253320SMarcel Koester       if (!regionReturnOperands)
20979716559SAlex Zinenko         return llvm::None;
21079716559SAlex Zinenko 
21104253320SMarcel Koester       // All successors get the same set of operand types.
21204253320SMarcel Koester       return TypeRange(regionReturnOperands->getTypes());
213a3ad8f92SRahul Joshi     };
214a3ad8f92SRahul Joshi 
215a3ad8f92SRahul Joshi     if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
216a3ad8f92SRahul Joshi       return failure();
217a3ad8f92SRahul Joshi   }
218a3ad8f92SRahul Joshi 
219a3ad8f92SRahul Joshi   return success();
220a3ad8f92SRahul Joshi }
22104253320SMarcel Koester 
222a5c2f782SMatthias Springer /// Return `true` if `a` and `b` are in mutually exclusive regions.
223a5c2f782SMatthias Springer ///
224a5c2f782SMatthias Springer /// 1. Find the first common of `a` and `b` (ancestor) that implements
225a5c2f782SMatthias Springer ///    RegionBranchOpInterface.
226a5c2f782SMatthias Springer /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
227a5c2f782SMatthias Springer ///    contained.
228a5c2f782SMatthias Springer /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
229a5c2f782SMatthias Springer ///    mutually exclusive if they are not reachable from each other as per
230a5c2f782SMatthias Springer ///    RegionBranchOpInterface::getSuccessorRegions.
231a5c2f782SMatthias Springer bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
232a5c2f782SMatthias Springer   assert(a && "expected non-empty operation");
233a5c2f782SMatthias Springer   assert(b && "expected non-empty operation");
234a5c2f782SMatthias Springer 
235a5c2f782SMatthias Springer   auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
236a5c2f782SMatthias Springer   while (branchOp) {
237a5c2f782SMatthias Springer     // Check if b is inside branchOp. (We already know that a is.)
238a5c2f782SMatthias Springer     if (!branchOp->isProperAncestor(b)) {
239a5c2f782SMatthias Springer       // Check next enclosing RegionBranchOpInterface.
240a5c2f782SMatthias Springer       branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
241a5c2f782SMatthias Springer       continue;
242a5c2f782SMatthias Springer     }
243a5c2f782SMatthias Springer 
244a5c2f782SMatthias Springer     // b is contained in branchOp. Retrieve the regions in which `a` and `b`
245a5c2f782SMatthias Springer     // are contained.
246a5c2f782SMatthias Springer     Region *regionA = nullptr, *regionB = nullptr;
247a5c2f782SMatthias Springer     for (Region &r : branchOp->getRegions()) {
248a5c2f782SMatthias Springer       if (r.findAncestorOpInRegion(*a)) {
249a5c2f782SMatthias Springer         assert(!regionA && "already found a region for a");
250a5c2f782SMatthias Springer         regionA = &r;
251a5c2f782SMatthias Springer       }
252a5c2f782SMatthias Springer       if (r.findAncestorOpInRegion(*b)) {
253a5c2f782SMatthias Springer         assert(!regionB && "already found a region for b");
254a5c2f782SMatthias Springer         regionB = &r;
255a5c2f782SMatthias Springer       }
256a5c2f782SMatthias Springer     }
257a5c2f782SMatthias Springer     assert(regionA && regionB && "could not find region of op");
258a5c2f782SMatthias Springer 
259a5c2f782SMatthias Springer     // Helper function that checks if region `r` is reachable from region
260a5c2f782SMatthias Springer     // `begin`.
261a5c2f782SMatthias Springer     std::function<bool(Region *, Region *)> isRegionReachable =
262a5c2f782SMatthias Springer         [&](Region *begin, Region *r) {
263a5c2f782SMatthias Springer           if (begin == r)
264a5c2f782SMatthias Springer             return true;
265a5c2f782SMatthias Springer           if (begin == nullptr)
266a5c2f782SMatthias Springer             return false;
267a5c2f782SMatthias Springer           // Compute index of region.
268a5c2f782SMatthias Springer           int64_t beginIndex = -1;
269*e4853be2SMehdi Amini           for (const auto &it : llvm::enumerate(branchOp->getRegions()))
270a5c2f782SMatthias Springer             if (&it.value() == begin)
271a5c2f782SMatthias Springer               beginIndex = it.index();
272a5c2f782SMatthias Springer           assert(beginIndex != -1 && "could not find region in op");
273a5c2f782SMatthias Springer           // Retrieve all successors of the region.
274a5c2f782SMatthias Springer           SmallVector<RegionSuccessor> successors;
275a5c2f782SMatthias Springer           branchOp.getSuccessorRegions(beginIndex, successors);
276a5c2f782SMatthias Springer           // Call function recursively on all successors.
277a5c2f782SMatthias Springer           for (RegionSuccessor successor : successors)
278a5c2f782SMatthias Springer             if (isRegionReachable(successor.getSuccessor(), r))
279a5c2f782SMatthias Springer               return true;
280a5c2f782SMatthias Springer           return false;
281a5c2f782SMatthias Springer         };
282a5c2f782SMatthias Springer 
283a5c2f782SMatthias Springer     // `a` and `b` are in mutually exclusive regions if neither region is
284a5c2f782SMatthias Springer     // reachable from the other region.
285a5c2f782SMatthias Springer     return !isRegionReachable(regionA, regionB) &&
286a5c2f782SMatthias Springer            !isRegionReachable(regionB, regionA);
287a5c2f782SMatthias Springer   }
288a5c2f782SMatthias Springer 
289a5c2f782SMatthias Springer   // Could not find a common RegionBranchOpInterface among a's and b's
290a5c2f782SMatthias Springer   // ancestors.
291a5c2f782SMatthias Springer   return false;
292a5c2f782SMatthias Springer }
293a5c2f782SMatthias Springer 
29404253320SMarcel Koester //===----------------------------------------------------------------------===//
29504253320SMarcel Koester // RegionBranchTerminatorOpInterface
29604253320SMarcel Koester //===----------------------------------------------------------------------===//
29704253320SMarcel Koester 
29804253320SMarcel Koester /// Returns true if the given operation is either annotated with the
29904253320SMarcel Koester /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
30004253320SMarcel Koester bool mlir::isRegionReturnLike(Operation *operation) {
30104253320SMarcel Koester   return dyn_cast<RegionBranchTerminatorOpInterface>(operation) ||
30204253320SMarcel Koester          operation->hasTrait<OpTrait::ReturnLike>();
30304253320SMarcel Koester }
30404253320SMarcel Koester 
30504253320SMarcel Koester /// Returns the mutable operands that are passed to the region with the given
30604253320SMarcel Koester /// `regionIndex`. If the operation does not implement the
30704253320SMarcel Koester /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
30804253320SMarcel Koester /// result will be `llvm::None`. In all other cases, the resulting
30904253320SMarcel Koester /// `OperandRange` represents all operands that are passed to the specified
31004253320SMarcel Koester /// successor region. If `regionIndex` is `llvm::None`, all operands that are
31104253320SMarcel Koester /// passed to the parent operation will be returned.
31204253320SMarcel Koester Optional<MutableOperandRange>
31304253320SMarcel Koester mlir::getMutableRegionBranchSuccessorOperands(Operation *operation,
31404253320SMarcel Koester                                               Optional<unsigned> regionIndex) {
31504253320SMarcel Koester   // Try to query a RegionBranchTerminatorOpInterface to determine
31604253320SMarcel Koester   // all successor operands that will be passed to the successor
31704253320SMarcel Koester   // input arguments.
31804253320SMarcel Koester   if (auto regionTerminatorInterface =
31904253320SMarcel Koester           dyn_cast<RegionBranchTerminatorOpInterface>(operation))
32004253320SMarcel Koester     return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex);
32104253320SMarcel Koester 
32204253320SMarcel Koester   // TODO: The ReturnLike trait should imply a default implementation of the
32304253320SMarcel Koester   // RegionBranchTerminatorOpInterface. This would make this code significantly
32404253320SMarcel Koester   // easier. Furthermore, this may even make this function obsolete.
32504253320SMarcel Koester   if (operation->hasTrait<OpTrait::ReturnLike>())
32604253320SMarcel Koester     return MutableOperandRange(operation);
32704253320SMarcel Koester   return llvm::None;
32804253320SMarcel Koester }
32904253320SMarcel Koester 
33004253320SMarcel Koester /// Returns the read only operands that are passed to the region with the given
33104253320SMarcel Koester /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
33204253320SMarcel Koester /// information.
33304253320SMarcel Koester Optional<OperandRange>
33404253320SMarcel Koester mlir::getRegionBranchSuccessorOperands(Operation *operation,
33504253320SMarcel Koester                                        Optional<unsigned> regionIndex) {
33604253320SMarcel Koester   auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex);
33704253320SMarcel Koester   return range ? Optional<OperandRange>(*range) : llvm::None;
33804253320SMarcel Koester }
339