17ce1e7abSRiver Riddle //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
27ce1e7abSRiver Riddle //
37ce1e7abSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47ce1e7abSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
57ce1e7abSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67ce1e7abSRiver Riddle //
77ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
87ce1e7abSRiver Riddle 
97ce1e7abSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.h"
107ce1e7abSRiver Riddle #include "mlir/IR/StandardTypes.h"
11a3ad8f92SRahul Joshi #include "llvm/ADT/SmallPtrSet.h"
127ce1e7abSRiver Riddle 
137ce1e7abSRiver Riddle using namespace mlir;
147ce1e7abSRiver Riddle 
157ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
167ce1e7abSRiver Riddle // ControlFlowInterfaces
177ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
187ce1e7abSRiver Riddle 
197ce1e7abSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
207ce1e7abSRiver Riddle 
217ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
227ce1e7abSRiver Riddle // BranchOpInterface
237ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
247ce1e7abSRiver Riddle 
257ce1e7abSRiver Riddle /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
267ce1e7abSRiver Riddle /// successor if 'operandIndex' is within the range of 'operands', or None if
277ce1e7abSRiver Riddle /// `operandIndex` isn't a successor operand index.
28a3ad8f92SRahul Joshi Optional<BlockArgument>
29a3ad8f92SRahul Joshi detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
30a3ad8f92SRahul Joshi                                    unsigned operandIndex, Block *successor) {
317ce1e7abSRiver Riddle   // Check that the operands are valid.
327ce1e7abSRiver Riddle   if (!operands || operands->empty())
337ce1e7abSRiver Riddle     return llvm::None;
347ce1e7abSRiver Riddle 
357ce1e7abSRiver Riddle   // Check to ensure that this operand is within the range.
367ce1e7abSRiver Riddle   unsigned operandsStart = operands->getBeginOperandIndex();
377ce1e7abSRiver Riddle   if (operandIndex < operandsStart ||
387ce1e7abSRiver Riddle       operandIndex >= (operandsStart + operands->size()))
397ce1e7abSRiver Riddle     return llvm::None;
407ce1e7abSRiver Riddle 
417ce1e7abSRiver Riddle   // Index the successor.
427ce1e7abSRiver Riddle   unsigned argIndex = operandIndex - operandsStart;
437ce1e7abSRiver Riddle   return successor->getArgument(argIndex);
447ce1e7abSRiver Riddle }
457ce1e7abSRiver Riddle 
467ce1e7abSRiver Riddle /// Verify that the given operands match those of the given successor block.
477ce1e7abSRiver Riddle LogicalResult
48a3ad8f92SRahul Joshi detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
497ce1e7abSRiver Riddle                                       Optional<OperandRange> operands) {
507ce1e7abSRiver Riddle   if (!operands)
517ce1e7abSRiver Riddle     return success();
527ce1e7abSRiver Riddle 
537ce1e7abSRiver Riddle   // Check the count.
547ce1e7abSRiver Riddle   unsigned operandCount = operands->size();
557ce1e7abSRiver Riddle   Block *destBB = op->getSuccessor(succNo);
567ce1e7abSRiver Riddle   if (operandCount != destBB->getNumArguments())
577ce1e7abSRiver Riddle     return op->emitError() << "branch has " << operandCount
587ce1e7abSRiver Riddle                            << " operands for successor #" << succNo
597ce1e7abSRiver Riddle                            << ", but target block has "
607ce1e7abSRiver Riddle                            << destBB->getNumArguments();
617ce1e7abSRiver Riddle 
627ce1e7abSRiver Riddle   // Check the types.
637ce1e7abSRiver Riddle   auto operandIt = operands->begin();
647ce1e7abSRiver Riddle   for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
657ce1e7abSRiver Riddle     if ((*operandIt).getType() != destBB->getArgument(i).getType())
667ce1e7abSRiver Riddle       return op->emitError() << "type mismatch for bb argument #" << i
677ce1e7abSRiver Riddle                              << " of successor #" << succNo;
687ce1e7abSRiver Riddle   }
697ce1e7abSRiver Riddle   return success();
707ce1e7abSRiver Riddle }
71a3ad8f92SRahul Joshi 
72a3ad8f92SRahul Joshi //===----------------------------------------------------------------------===//
73a3ad8f92SRahul Joshi // RegionBranchOpInterface
74a3ad8f92SRahul Joshi //===----------------------------------------------------------------------===//
75a3ad8f92SRahul Joshi 
76a3ad8f92SRahul Joshi /// Verify that types match along all region control flow edges originating from
77a3ad8f92SRahul Joshi /// `sourceNo` (region # if source is a region, llvm::None if source is parent
78a3ad8f92SRahul Joshi /// op). `getInputsTypesForRegion` is a function that returns the types of the
79a3ad8f92SRahul Joshi /// inputs that flow from `sourceIndex' to the given region.
80a3ad8f92SRahul Joshi static LogicalResult verifyTypesAlongAllEdges(
81a3ad8f92SRahul Joshi     Operation *op, Optional<unsigned> sourceNo,
82a3ad8f92SRahul Joshi     function_ref<TypeRange(Optional<unsigned>)> getInputsTypesForRegion) {
83a3ad8f92SRahul Joshi   auto regionInterface = cast<RegionBranchOpInterface>(op);
84a3ad8f92SRahul Joshi 
85a3ad8f92SRahul Joshi   SmallVector<RegionSuccessor, 2> successors;
86a3ad8f92SRahul Joshi   unsigned numInputs;
87a3ad8f92SRahul Joshi   if (sourceNo) {
88a3ad8f92SRahul Joshi     Region &srcRegion = op->getRegion(sourceNo.getValue());
89a3ad8f92SRahul Joshi     numInputs = srcRegion.getNumArguments();
90a3ad8f92SRahul Joshi   } else {
91a3ad8f92SRahul Joshi     numInputs = op->getNumOperands();
92a3ad8f92SRahul Joshi   }
93a3ad8f92SRahul Joshi   SmallVector<Attribute, 2> operands(numInputs, nullptr);
94a3ad8f92SRahul Joshi   regionInterface.getSuccessorRegions(sourceNo, operands, successors);
95a3ad8f92SRahul Joshi 
96a3ad8f92SRahul Joshi   for (RegionSuccessor &succ : successors) {
97a3ad8f92SRahul Joshi     Optional<unsigned> succRegionNo;
98a3ad8f92SRahul Joshi     if (!succ.isParent())
99a3ad8f92SRahul Joshi       succRegionNo = succ.getSuccessor()->getRegionNumber();
100a3ad8f92SRahul Joshi 
101a3ad8f92SRahul Joshi     auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
102a3ad8f92SRahul Joshi       diag << "from ";
103a3ad8f92SRahul Joshi       if (sourceNo)
104a3ad8f92SRahul Joshi         diag << "Region #" << sourceNo.getValue();
105a3ad8f92SRahul Joshi       else
106be35264aSSean Silva         diag << "parent operands";
107a3ad8f92SRahul Joshi 
108a3ad8f92SRahul Joshi       diag << " to ";
109a3ad8f92SRahul Joshi       if (succRegionNo)
110a3ad8f92SRahul Joshi         diag << "Region #" << succRegionNo.getValue();
111a3ad8f92SRahul Joshi       else
112be35264aSSean Silva         diag << "parent results";
113a3ad8f92SRahul Joshi       return diag;
114a3ad8f92SRahul Joshi     };
115a3ad8f92SRahul Joshi 
116a3ad8f92SRahul Joshi     TypeRange sourceTypes = getInputsTypesForRegion(succRegionNo);
117a3ad8f92SRahul Joshi     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
118a3ad8f92SRahul Joshi     if (sourceTypes.size() != succInputsTypes.size()) {
119a3ad8f92SRahul Joshi       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
120be35264aSSean Silva       return printEdgeName(diag) << ": source has " << sourceTypes.size()
121be35264aSSean Silva                                  << " operands, but target successor needs "
122a3ad8f92SRahul Joshi                                  << succInputsTypes.size();
123a3ad8f92SRahul Joshi     }
124a3ad8f92SRahul Joshi 
125a3ad8f92SRahul Joshi     for (auto typesIdx :
126a3ad8f92SRahul Joshi          llvm::enumerate(llvm::zip(sourceTypes, succInputsTypes))) {
127a3ad8f92SRahul Joshi       Type sourceType = std::get<0>(typesIdx.value());
128a3ad8f92SRahul Joshi       Type inputType = std::get<1>(typesIdx.value());
129a3ad8f92SRahul Joshi       if (sourceType != inputType) {
130a3ad8f92SRahul Joshi         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
131a3ad8f92SRahul Joshi         return printEdgeName(diag)
132be35264aSSean Silva                << ": source type #" << typesIdx.index() << " " << sourceType
133be35264aSSean Silva                << " should match input type #" << typesIdx.index() << " "
134a3ad8f92SRahul Joshi                << inputType;
135a3ad8f92SRahul Joshi       }
136a3ad8f92SRahul Joshi     }
137a3ad8f92SRahul Joshi   }
138a3ad8f92SRahul Joshi   return success();
139a3ad8f92SRahul Joshi }
140a3ad8f92SRahul Joshi 
141a3ad8f92SRahul Joshi /// Verify that types match along control flow edges described the given op.
142a3ad8f92SRahul Joshi LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
143a3ad8f92SRahul Joshi   auto regionInterface = cast<RegionBranchOpInterface>(op);
144a3ad8f92SRahul Joshi 
145a3ad8f92SRahul Joshi   auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
146a3ad8f92SRahul Joshi     if (regionNo.hasValue()) {
147a3ad8f92SRahul Joshi       return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
148a3ad8f92SRahul Joshi           .getTypes();
149a3ad8f92SRahul Joshi     }
150a3ad8f92SRahul Joshi 
151a3ad8f92SRahul Joshi     // If the successor of a parent op is the parent itself
152a3ad8f92SRahul Joshi     // RegionBranchOpInterface does not have an API to query what the entry
153a3ad8f92SRahul Joshi     // operands will be in that case. Vend out the result types of the op in
154a3ad8f92SRahul Joshi     // that case so that type checking succeeds for this case.
155a3ad8f92SRahul Joshi     return op->getResultTypes();
156a3ad8f92SRahul Joshi   };
157a3ad8f92SRahul Joshi 
158a3ad8f92SRahul Joshi   // Verify types along control flow edges originating from the parent.
159a3ad8f92SRahul Joshi   if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
160a3ad8f92SRahul Joshi     return failure();
161a3ad8f92SRahul Joshi 
162a3ad8f92SRahul Joshi   // RegionBranchOpInterface should not be implemented by Ops that do not have
163a3ad8f92SRahul Joshi   // attached regions.
164a3ad8f92SRahul Joshi   assert(op->getNumRegions() != 0);
165a3ad8f92SRahul Joshi 
166a3ad8f92SRahul Joshi   // Verify types along control flow edges originating from each region.
167a3ad8f92SRahul Joshi   for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
168a3ad8f92SRahul Joshi     Region &region = op->getRegion(regionNo);
169a3ad8f92SRahul Joshi 
170*41b09f4eSKazuaki Ishizaki     // Since the interface cannot distinguish between different ReturnLike
171a3ad8f92SRahul Joshi     // ops within the region branching to different successors, all ReturnLike
172a3ad8f92SRahul Joshi     // ops in this region should have the same operand types. We will then use
173a3ad8f92SRahul Joshi     // one of them as the representative for type matching.
174a3ad8f92SRahul Joshi 
175a3ad8f92SRahul Joshi     Operation *regionReturn = nullptr;
176a3ad8f92SRahul Joshi     for (Block &block : region) {
177a3ad8f92SRahul Joshi       Operation *terminator = block.getTerminator();
178a3ad8f92SRahul Joshi       if (!terminator->hasTrait<OpTrait::ReturnLike>())
179a3ad8f92SRahul Joshi         continue;
180a3ad8f92SRahul Joshi 
181a3ad8f92SRahul Joshi       if (!regionReturn) {
182a3ad8f92SRahul Joshi         regionReturn = terminator;
183a3ad8f92SRahul Joshi         continue;
184a3ad8f92SRahul Joshi       }
185a3ad8f92SRahul Joshi 
186a3ad8f92SRahul Joshi       // Found more than one ReturnLike terminator. Make sure the operand types
187a3ad8f92SRahul Joshi       // match with the first one.
188a3ad8f92SRahul Joshi       if (regionReturn->getOperandTypes() != terminator->getOperandTypes())
189a3ad8f92SRahul Joshi         return op->emitOpError("Region #")
190a3ad8f92SRahul Joshi                << regionNo
191a3ad8f92SRahul Joshi                << " operands mismatch between return-like terminators";
192a3ad8f92SRahul Joshi     }
193a3ad8f92SRahul Joshi 
194a3ad8f92SRahul Joshi     auto inputTypesFromRegion = [&](Optional<unsigned> regionNo) -> TypeRange {
195a3ad8f92SRahul Joshi       // All successors get the same set of operands.
196a3ad8f92SRahul Joshi       return regionReturn ? TypeRange(regionReturn->getOperands().getTypes())
197a3ad8f92SRahul Joshi                           : TypeRange();
198a3ad8f92SRahul Joshi     };
199a3ad8f92SRahul Joshi 
200a3ad8f92SRahul Joshi     if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
201a3ad8f92SRahul Joshi       return failure();
202a3ad8f92SRahul Joshi   }
203a3ad8f92SRahul Joshi 
204a3ad8f92SRahul Joshi   return success();
205a3ad8f92SRahul Joshi }
206