1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ControlFlowInterfaces.h"
10 #include "mlir/IR/StandardTypes.h"
11 #include "llvm/ADT/SmallPtrSet.h"
12 
13 using namespace mlir;
14 
15 //===----------------------------------------------------------------------===//
16 // ControlFlowInterfaces
17 //===----------------------------------------------------------------------===//
18 
19 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
20 
21 //===----------------------------------------------------------------------===//
22 // BranchOpInterface
23 //===----------------------------------------------------------------------===//
24 
25 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
26 /// successor if 'operandIndex' is within the range of 'operands', or None if
27 /// `operandIndex` isn't a successor operand index.
28 Optional<BlockArgument>
29 detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
30                                    unsigned operandIndex, Block *successor) {
31   // Check that the operands are valid.
32   if (!operands || operands->empty())
33     return llvm::None;
34 
35   // Check to ensure that this operand is within the range.
36   unsigned operandsStart = operands->getBeginOperandIndex();
37   if (operandIndex < operandsStart ||
38       operandIndex >= (operandsStart + operands->size()))
39     return llvm::None;
40 
41   // Index the successor.
42   unsigned argIndex = operandIndex - operandsStart;
43   return successor->getArgument(argIndex);
44 }
45 
46 /// Verify that the given operands match those of the given successor block.
47 LogicalResult
48 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
49                                       Optional<OperandRange> operands) {
50   if (!operands)
51     return success();
52 
53   // Check the count.
54   unsigned operandCount = operands->size();
55   Block *destBB = op->getSuccessor(succNo);
56   if (operandCount != destBB->getNumArguments())
57     return op->emitError() << "branch has " << operandCount
58                            << " operands for successor #" << succNo
59                            << ", but target block has "
60                            << destBB->getNumArguments();
61 
62   // Check the types.
63   auto operandIt = operands->begin();
64   for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
65     if ((*operandIt).getType() != destBB->getArgument(i).getType())
66       return op->emitError() << "type mismatch for bb argument #" << i
67                              << " of successor #" << succNo;
68   }
69   return success();
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // RegionBranchOpInterface
74 //===----------------------------------------------------------------------===//
75 
76 /// Verify that types match along all region control flow edges originating from
77 /// `sourceNo` (region # if source is a region, llvm::None if source is parent
78 /// op). `getInputsTypesForRegion` is a function that returns the types of the
79 /// inputs that flow from `sourceIndex' to the given region.
80 static LogicalResult verifyTypesAlongAllEdges(
81     Operation *op, Optional<unsigned> sourceNo,
82     function_ref<TypeRange(Optional<unsigned>)> getInputsTypesForRegion) {
83   auto regionInterface = cast<RegionBranchOpInterface>(op);
84 
85   SmallVector<RegionSuccessor, 2> successors;
86   unsigned numInputs;
87   if (sourceNo) {
88     Region &srcRegion = op->getRegion(sourceNo.getValue());
89     numInputs = srcRegion.getNumArguments();
90   } else {
91     numInputs = op->getNumOperands();
92   }
93   SmallVector<Attribute, 2> operands(numInputs, nullptr);
94   regionInterface.getSuccessorRegions(sourceNo, operands, successors);
95 
96   for (RegionSuccessor &succ : successors) {
97     Optional<unsigned> succRegionNo;
98     if (!succ.isParent())
99       succRegionNo = succ.getSuccessor()->getRegionNumber();
100 
101     auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
102       diag << "from ";
103       if (sourceNo)
104         diag << "Region #" << sourceNo.getValue();
105       else
106         diag << "parent operands";
107 
108       diag << " to ";
109       if (succRegionNo)
110         diag << "Region #" << succRegionNo.getValue();
111       else
112         diag << "parent results";
113       return diag;
114     };
115 
116     TypeRange sourceTypes = getInputsTypesForRegion(succRegionNo);
117     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
118     if (sourceTypes.size() != succInputsTypes.size()) {
119       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
120       return printEdgeName(diag) << ": source has " << sourceTypes.size()
121                                  << " operands, but target successor needs "
122                                  << succInputsTypes.size();
123     }
124 
125     for (auto typesIdx :
126          llvm::enumerate(llvm::zip(sourceTypes, succInputsTypes))) {
127       Type sourceType = std::get<0>(typesIdx.value());
128       Type inputType = std::get<1>(typesIdx.value());
129       if (sourceType != inputType) {
130         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
131         return printEdgeName(diag)
132                << ": source type #" << typesIdx.index() << " " << sourceType
133                << " should match input type #" << typesIdx.index() << " "
134                << inputType;
135       }
136     }
137   }
138   return success();
139 }
140 
141 /// Verify that types match along control flow edges described the given op.
142 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
143   auto regionInterface = cast<RegionBranchOpInterface>(op);
144 
145   auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
146     if (regionNo.hasValue()) {
147       return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
148           .getTypes();
149     }
150 
151     // If the successor of a parent op is the parent itself
152     // RegionBranchOpInterface does not have an API to query what the entry
153     // operands will be in that case. Vend out the result types of the op in
154     // that case so that type checking succeeds for this case.
155     return op->getResultTypes();
156   };
157 
158   // Verify types along control flow edges originating from the parent.
159   if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
160     return failure();
161 
162   // RegionBranchOpInterface should not be implemented by Ops that do not have
163   // attached regions.
164   assert(op->getNumRegions() != 0);
165 
166   // Verify types along control flow edges originating from each region.
167   for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
168     Region &region = op->getRegion(regionNo);
169 
170     // Since the interface cannnot distinguish between different ReturnLike
171     // ops within the region branching to different successors, all ReturnLike
172     // ops in this region should have the same operand types. We will then use
173     // one of them as the representative for type matching.
174 
175     Operation *regionReturn = nullptr;
176     for (Block &block : region) {
177       Operation *terminator = block.getTerminator();
178       if (!terminator->hasTrait<OpTrait::ReturnLike>())
179         continue;
180 
181       if (!regionReturn) {
182         regionReturn = terminator;
183         continue;
184       }
185 
186       // Found more than one ReturnLike terminator. Make sure the operand types
187       // match with the first one.
188       if (regionReturn->getOperandTypes() != terminator->getOperandTypes())
189         return op->emitOpError("Region #")
190                << regionNo
191                << " operands mismatch between return-like terminators";
192     }
193 
194     auto inputTypesFromRegion = [&](Optional<unsigned> regionNo) -> TypeRange {
195       // All successors get the same set of operands.
196       return regionReturn ? TypeRange(regionReturn->getOperands().getTypes())
197                           : TypeRange();
198     };
199 
200     if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
201       return failure();
202   }
203 
204   return success();
205 }
206