17ce1e7abSRiver Riddle //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
27ce1e7abSRiver Riddle //
37ce1e7abSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47ce1e7abSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
57ce1e7abSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67ce1e7abSRiver Riddle //
77ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
87ce1e7abSRiver Riddle 
97ce1e7abSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.h"
107ce1e7abSRiver Riddle #include "mlir/IR/StandardTypes.h"
11a3ad8f92SRahul Joshi #include "llvm/ADT/SmallPtrSet.h"
127ce1e7abSRiver Riddle 
137ce1e7abSRiver Riddle using namespace mlir;
147ce1e7abSRiver Riddle 
157ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
167ce1e7abSRiver Riddle // ControlFlowInterfaces
177ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
187ce1e7abSRiver Riddle 
197ce1e7abSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
207ce1e7abSRiver Riddle 
217ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
227ce1e7abSRiver Riddle // BranchOpInterface
237ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
247ce1e7abSRiver Riddle 
257ce1e7abSRiver Riddle /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
267ce1e7abSRiver Riddle /// successor if 'operandIndex' is within the range of 'operands', or None if
277ce1e7abSRiver Riddle /// `operandIndex` isn't a successor operand index.
28a3ad8f92SRahul Joshi Optional<BlockArgument>
29a3ad8f92SRahul Joshi detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
30a3ad8f92SRahul Joshi                                    unsigned operandIndex, Block *successor) {
317ce1e7abSRiver Riddle   // Check that the operands are valid.
327ce1e7abSRiver Riddle   if (!operands || operands->empty())
337ce1e7abSRiver Riddle     return llvm::None;
347ce1e7abSRiver Riddle 
357ce1e7abSRiver Riddle   // Check to ensure that this operand is within the range.
367ce1e7abSRiver Riddle   unsigned operandsStart = operands->getBeginOperandIndex();
377ce1e7abSRiver Riddle   if (operandIndex < operandsStart ||
387ce1e7abSRiver Riddle       operandIndex >= (operandsStart + operands->size()))
397ce1e7abSRiver Riddle     return llvm::None;
407ce1e7abSRiver Riddle 
417ce1e7abSRiver Riddle   // Index the successor.
427ce1e7abSRiver Riddle   unsigned argIndex = operandIndex - operandsStart;
437ce1e7abSRiver Riddle   return successor->getArgument(argIndex);
447ce1e7abSRiver Riddle }
457ce1e7abSRiver Riddle 
467ce1e7abSRiver Riddle /// Verify that the given operands match those of the given successor block.
477ce1e7abSRiver Riddle LogicalResult
48a3ad8f92SRahul Joshi detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
497ce1e7abSRiver Riddle                                       Optional<OperandRange> operands) {
507ce1e7abSRiver Riddle   if (!operands)
517ce1e7abSRiver Riddle     return success();
527ce1e7abSRiver Riddle 
537ce1e7abSRiver Riddle   // Check the count.
547ce1e7abSRiver Riddle   unsigned operandCount = operands->size();
557ce1e7abSRiver Riddle   Block *destBB = op->getSuccessor(succNo);
567ce1e7abSRiver Riddle   if (operandCount != destBB->getNumArguments())
577ce1e7abSRiver Riddle     return op->emitError() << "branch has " << operandCount
587ce1e7abSRiver Riddle                            << " operands for successor #" << succNo
597ce1e7abSRiver Riddle                            << ", but target block has "
607ce1e7abSRiver Riddle                            << destBB->getNumArguments();
617ce1e7abSRiver Riddle 
627ce1e7abSRiver Riddle   // Check the types.
637ce1e7abSRiver Riddle   auto operandIt = operands->begin();
647ce1e7abSRiver Riddle   for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
657ce1e7abSRiver Riddle     if ((*operandIt).getType() != destBB->getArgument(i).getType())
667ce1e7abSRiver Riddle       return op->emitError() << "type mismatch for bb argument #" << i
677ce1e7abSRiver Riddle                              << " of successor #" << succNo;
687ce1e7abSRiver Riddle   }
697ce1e7abSRiver Riddle   return success();
707ce1e7abSRiver Riddle }
71a3ad8f92SRahul Joshi 
72a3ad8f92SRahul Joshi //===----------------------------------------------------------------------===//
73a3ad8f92SRahul Joshi // RegionBranchOpInterface
74a3ad8f92SRahul Joshi //===----------------------------------------------------------------------===//
75a3ad8f92SRahul Joshi 
76*bb0d5f76SEugene Zhulenev // A constant value to represent unknown number of region invocations.
77*bb0d5f76SEugene Zhulenev const int64_t mlir::kUnknownNumRegionInvocations = -1;
78*bb0d5f76SEugene Zhulenev 
79a3ad8f92SRahul Joshi /// Verify that types match along all region control flow edges originating from
80a3ad8f92SRahul Joshi /// `sourceNo` (region # if source is a region, llvm::None if source is parent
81a3ad8f92SRahul Joshi /// op). `getInputsTypesForRegion` is a function that returns the types of the
8279716559SAlex Zinenko /// inputs that flow from `sourceIndex' to the given region, or llvm::None if
8379716559SAlex Zinenko /// the exact type match verification is not necessary (e.g., if the Op verifies
8479716559SAlex Zinenko /// the match itself).
8579716559SAlex Zinenko static LogicalResult
8679716559SAlex Zinenko verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
8779716559SAlex Zinenko                          function_ref<Optional<TypeRange>(Optional<unsigned>)>
8879716559SAlex Zinenko                              getInputsTypesForRegion) {
89a3ad8f92SRahul Joshi   auto regionInterface = cast<RegionBranchOpInterface>(op);
90a3ad8f92SRahul Joshi 
91a3ad8f92SRahul Joshi   SmallVector<RegionSuccessor, 2> successors;
92a3ad8f92SRahul Joshi   unsigned numInputs;
93a3ad8f92SRahul Joshi   if (sourceNo) {
94a3ad8f92SRahul Joshi     Region &srcRegion = op->getRegion(sourceNo.getValue());
95a3ad8f92SRahul Joshi     numInputs = srcRegion.getNumArguments();
96a3ad8f92SRahul Joshi   } else {
97a3ad8f92SRahul Joshi     numInputs = op->getNumOperands();
98a3ad8f92SRahul Joshi   }
99a3ad8f92SRahul Joshi   SmallVector<Attribute, 2> operands(numInputs, nullptr);
100a3ad8f92SRahul Joshi   regionInterface.getSuccessorRegions(sourceNo, operands, successors);
101a3ad8f92SRahul Joshi 
102a3ad8f92SRahul Joshi   for (RegionSuccessor &succ : successors) {
103a3ad8f92SRahul Joshi     Optional<unsigned> succRegionNo;
104a3ad8f92SRahul Joshi     if (!succ.isParent())
105a3ad8f92SRahul Joshi       succRegionNo = succ.getSuccessor()->getRegionNumber();
106a3ad8f92SRahul Joshi 
107a3ad8f92SRahul Joshi     auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
108a3ad8f92SRahul Joshi       diag << "from ";
109a3ad8f92SRahul Joshi       if (sourceNo)
110a3ad8f92SRahul Joshi         diag << "Region #" << sourceNo.getValue();
111a3ad8f92SRahul Joshi       else
112be35264aSSean Silva         diag << "parent operands";
113a3ad8f92SRahul Joshi 
114a3ad8f92SRahul Joshi       diag << " to ";
115a3ad8f92SRahul Joshi       if (succRegionNo)
116a3ad8f92SRahul Joshi         diag << "Region #" << succRegionNo.getValue();
117a3ad8f92SRahul Joshi       else
118be35264aSSean Silva         diag << "parent results";
119a3ad8f92SRahul Joshi       return diag;
120a3ad8f92SRahul Joshi     };
121a3ad8f92SRahul Joshi 
12279716559SAlex Zinenko     Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
12379716559SAlex Zinenko     if (!sourceTypes.hasValue())
12479716559SAlex Zinenko       continue;
12579716559SAlex Zinenko 
126a3ad8f92SRahul Joshi     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
12779716559SAlex Zinenko     if (sourceTypes->size() != succInputsTypes.size()) {
128a3ad8f92SRahul Joshi       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
12979716559SAlex Zinenko       return printEdgeName(diag) << ": source has " << sourceTypes->size()
130be35264aSSean Silva                                  << " operands, but target successor needs "
131a3ad8f92SRahul Joshi                                  << succInputsTypes.size();
132a3ad8f92SRahul Joshi     }
133a3ad8f92SRahul Joshi 
134a3ad8f92SRahul Joshi     for (auto typesIdx :
13579716559SAlex Zinenko          llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
136a3ad8f92SRahul Joshi       Type sourceType = std::get<0>(typesIdx.value());
137a3ad8f92SRahul Joshi       Type inputType = std::get<1>(typesIdx.value());
138a3ad8f92SRahul Joshi       if (sourceType != inputType) {
139a3ad8f92SRahul Joshi         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
140a3ad8f92SRahul Joshi         return printEdgeName(diag)
141be35264aSSean Silva                << ": source type #" << typesIdx.index() << " " << sourceType
142be35264aSSean Silva                << " should match input type #" << typesIdx.index() << " "
143a3ad8f92SRahul Joshi                << inputType;
144a3ad8f92SRahul Joshi       }
145a3ad8f92SRahul Joshi     }
146a3ad8f92SRahul Joshi   }
147a3ad8f92SRahul Joshi   return success();
148a3ad8f92SRahul Joshi }
149a3ad8f92SRahul Joshi 
150a3ad8f92SRahul Joshi /// Verify that types match along control flow edges described the given op.
151a3ad8f92SRahul Joshi LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
152a3ad8f92SRahul Joshi   auto regionInterface = cast<RegionBranchOpInterface>(op);
153a3ad8f92SRahul Joshi 
154a3ad8f92SRahul Joshi   auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
155a3ad8f92SRahul Joshi     if (regionNo.hasValue()) {
156a3ad8f92SRahul Joshi       return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
157a3ad8f92SRahul Joshi           .getTypes();
158a3ad8f92SRahul Joshi     }
159a3ad8f92SRahul Joshi 
160a3ad8f92SRahul Joshi     // If the successor of a parent op is the parent itself
161a3ad8f92SRahul Joshi     // RegionBranchOpInterface does not have an API to query what the entry
162a3ad8f92SRahul Joshi     // operands will be in that case. Vend out the result types of the op in
163a3ad8f92SRahul Joshi     // that case so that type checking succeeds for this case.
164a3ad8f92SRahul Joshi     return op->getResultTypes();
165a3ad8f92SRahul Joshi   };
166a3ad8f92SRahul Joshi 
167a3ad8f92SRahul Joshi   // Verify types along control flow edges originating from the parent.
168a3ad8f92SRahul Joshi   if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
169a3ad8f92SRahul Joshi     return failure();
170a3ad8f92SRahul Joshi 
171a3ad8f92SRahul Joshi   // RegionBranchOpInterface should not be implemented by Ops that do not have
172a3ad8f92SRahul Joshi   // attached regions.
173a3ad8f92SRahul Joshi   assert(op->getNumRegions() != 0);
174a3ad8f92SRahul Joshi 
175a3ad8f92SRahul Joshi   // Verify types along control flow edges originating from each region.
176a3ad8f92SRahul Joshi   for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
177a3ad8f92SRahul Joshi     Region &region = op->getRegion(regionNo);
178a3ad8f92SRahul Joshi 
17941b09f4eSKazuaki Ishizaki     // Since the interface cannot distinguish between different ReturnLike
180a3ad8f92SRahul Joshi     // ops within the region branching to different successors, all ReturnLike
181a3ad8f92SRahul Joshi     // ops in this region should have the same operand types. We will then use
182a3ad8f92SRahul Joshi     // one of them as the representative for type matching.
183a3ad8f92SRahul Joshi 
184a3ad8f92SRahul Joshi     Operation *regionReturn = nullptr;
185a3ad8f92SRahul Joshi     for (Block &block : region) {
186a3ad8f92SRahul Joshi       Operation *terminator = block.getTerminator();
187a3ad8f92SRahul Joshi       if (!terminator->hasTrait<OpTrait::ReturnLike>())
188a3ad8f92SRahul Joshi         continue;
189a3ad8f92SRahul Joshi 
190a3ad8f92SRahul Joshi       if (!regionReturn) {
191a3ad8f92SRahul Joshi         regionReturn = terminator;
192a3ad8f92SRahul Joshi         continue;
193a3ad8f92SRahul Joshi       }
194a3ad8f92SRahul Joshi 
195a3ad8f92SRahul Joshi       // Found more than one ReturnLike terminator. Make sure the operand types
196a3ad8f92SRahul Joshi       // match with the first one.
197a3ad8f92SRahul Joshi       if (regionReturn->getOperandTypes() != terminator->getOperandTypes())
198a3ad8f92SRahul Joshi         return op->emitOpError("Region #")
199a3ad8f92SRahul Joshi                << regionNo
200a3ad8f92SRahul Joshi                << " operands mismatch between return-like terminators";
201a3ad8f92SRahul Joshi     }
202a3ad8f92SRahul Joshi 
20379716559SAlex Zinenko     auto inputTypesFromRegion =
20479716559SAlex Zinenko         [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
20579716559SAlex Zinenko       // If there is no return-like terminator, the op itself should verify
20679716559SAlex Zinenko       // type consistency.
20779716559SAlex Zinenko       if (!regionReturn)
20879716559SAlex Zinenko         return llvm::None;
20979716559SAlex Zinenko 
210a3ad8f92SRahul Joshi       // All successors get the same set of operands.
21179716559SAlex Zinenko       return TypeRange(regionReturn->getOperands().getTypes());
212a3ad8f92SRahul Joshi     };
213a3ad8f92SRahul Joshi 
214a3ad8f92SRahul Joshi     if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
215a3ad8f92SRahul Joshi       return failure();
216a3ad8f92SRahul Joshi   }
217a3ad8f92SRahul Joshi 
218a3ad8f92SRahul Joshi   return success();
219a3ad8f92SRahul Joshi }
220