1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ControlFlowInterfaces.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "llvm/ADT/SmallPtrSet.h"
12 
13 using namespace mlir;
14 
15 //===----------------------------------------------------------------------===//
16 // ControlFlowInterfaces
17 //===----------------------------------------------------------------------===//
18 
19 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
20 
21 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
22     : producedOperandCount(0), forwardedOperands(forwardedOperands) {}
23 
24 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
25                                      MutableOperandRange forwardedOperands)
26     : producedOperandCount(producedOperandCount),
27       forwardedOperands(std::move(forwardedOperands)) {}
28 
29 //===----------------------------------------------------------------------===//
30 // BranchOpInterface
31 //===----------------------------------------------------------------------===//
32 
33 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
34 /// successor if 'operandIndex' is within the range of 'operands', or None if
35 /// `operandIndex` isn't a successor operand index.
36 Optional<BlockArgument>
37 detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
38                                    unsigned operandIndex, Block *successor) {
39   OperandRange forwardedOperands = operands.getForwardedOperands();
40   // Check that the operands are valid.
41   if (forwardedOperands.empty())
42     return llvm::None;
43 
44   // Check to ensure that this operand is within the range.
45   unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
46   if (operandIndex < operandsStart ||
47       operandIndex >= (operandsStart + forwardedOperands.size()))
48     return llvm::None;
49 
50   // Index the successor.
51   unsigned argIndex =
52       operands.getProducedOperandCount() + operandIndex - operandsStart;
53   return successor->getArgument(argIndex);
54 }
55 
56 /// Verify that the given operands match those of the given successor block.
57 LogicalResult
58 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
59                                       const SuccessorOperands &operands) {
60   // Check the count.
61   unsigned operandCount = operands.size();
62   Block *destBB = op->getSuccessor(succNo);
63   if (operandCount != destBB->getNumArguments())
64     return op->emitError() << "branch has " << operandCount
65                            << " operands for successor #" << succNo
66                            << ", but target block has "
67                            << destBB->getNumArguments();
68 
69   // Check the types.
70   for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
71        ++i) {
72     if (!cast<BranchOpInterface>(op).areTypesCompatible(
73             operands[i].getType(), destBB->getArgument(i).getType()))
74       return op->emitError() << "type mismatch for bb argument #" << i
75                              << " of successor #" << succNo;
76   }
77   return success();
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // RegionBranchOpInterface
82 //===----------------------------------------------------------------------===//
83 
84 /// Verify that types match along all region control flow edges originating from
85 /// `sourceNo` (region # if source is a region, llvm::None if source is parent
86 /// op). `getInputsTypesForRegion` is a function that returns the types of the
87 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if
88 /// the exact type match verification is not necessary (e.g., if the Op verifies
89 /// the match itself).
90 static LogicalResult
91 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
92                          function_ref<Optional<TypeRange>(Optional<unsigned>)>
93                              getInputsTypesForRegion) {
94   auto regionInterface = cast<RegionBranchOpInterface>(op);
95 
96   SmallVector<RegionSuccessor, 2> successors;
97   unsigned numInputs;
98   if (sourceNo) {
99     Region &srcRegion = op->getRegion(sourceNo.getValue());
100     numInputs = srcRegion.getNumArguments();
101   } else {
102     numInputs = op->getNumOperands();
103   }
104   SmallVector<Attribute, 2> operands(numInputs, nullptr);
105   regionInterface.getSuccessorRegions(sourceNo, operands, successors);
106 
107   for (RegionSuccessor &succ : successors) {
108     Optional<unsigned> succRegionNo;
109     if (!succ.isParent())
110       succRegionNo = succ.getSuccessor()->getRegionNumber();
111 
112     auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
113       diag << "from ";
114       if (sourceNo)
115         diag << "Region #" << sourceNo.getValue();
116       else
117         diag << "parent operands";
118 
119       diag << " to ";
120       if (succRegionNo)
121         diag << "Region #" << succRegionNo.getValue();
122       else
123         diag << "parent results";
124       return diag;
125     };
126 
127     Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
128     if (!sourceTypes.hasValue())
129       continue;
130 
131     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
132     if (sourceTypes->size() != succInputsTypes.size()) {
133       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
134       return printEdgeName(diag) << ": source has " << sourceTypes->size()
135                                  << " operands, but target successor needs "
136                                  << succInputsTypes.size();
137     }
138 
139     for (const auto &typesIdx :
140          llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
141       Type sourceType = std::get<0>(typesIdx.value());
142       Type inputType = std::get<1>(typesIdx.value());
143       if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
144         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
145         return printEdgeName(diag)
146                << ": source type #" << typesIdx.index() << " " << sourceType
147                << " should match input type #" << typesIdx.index() << " "
148                << inputType;
149       }
150     }
151   }
152   return success();
153 }
154 
155 /// Verify that types match along control flow edges described the given op.
156 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
157   auto regionInterface = cast<RegionBranchOpInterface>(op);
158 
159   auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
160     if (regionNo.hasValue()) {
161       return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
162           .getTypes();
163     }
164 
165     // If the successor of a parent op is the parent itself
166     // RegionBranchOpInterface does not have an API to query what the entry
167     // operands will be in that case. Vend out the result types of the op in
168     // that case so that type checking succeeds for this case.
169     return op->getResultTypes();
170   };
171 
172   // Verify types along control flow edges originating from the parent.
173   if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
174     return failure();
175 
176   // RegionBranchOpInterface should not be implemented by Ops that do not have
177   // attached regions.
178   assert(op->getNumRegions() != 0);
179 
180   auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
181     if (lhs.size() != rhs.size())
182       return false;
183     for (auto types : llvm::zip(lhs, rhs)) {
184       if (!regionInterface.areTypesCompatible(std::get<0>(types),
185                                               std::get<1>(types))) {
186         return false;
187       }
188     }
189     return true;
190   };
191 
192   // Verify types along control flow edges originating from each region.
193   for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
194     Region &region = op->getRegion(regionNo);
195 
196     // Since there can be multiple `ReturnLike` terminators or others
197     // implementing the `RegionBranchTerminatorOpInterface`, all should have the
198     // same operand types when passing them to the same region.
199 
200     Optional<OperandRange> regionReturnOperands;
201     for (Block &block : region) {
202       Operation *terminator = block.getTerminator();
203       auto terminatorOperands =
204           getRegionBranchSuccessorOperands(terminator, regionNo);
205       if (!terminatorOperands)
206         continue;
207 
208       if (!regionReturnOperands) {
209         regionReturnOperands = terminatorOperands;
210         continue;
211       }
212 
213       // Found more than one ReturnLike terminator. Make sure the operand types
214       // match with the first one.
215       if (!areTypesCompatible(regionReturnOperands->getTypes(),
216                               terminatorOperands->getTypes()))
217         return op->emitOpError("Region #")
218                << regionNo
219                << " operands mismatch between return-like terminators";
220     }
221 
222     auto inputTypesFromRegion =
223         [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
224       // If there is no return-like terminator, the op itself should verify
225       // type consistency.
226       if (!regionReturnOperands)
227         return llvm::None;
228 
229       // All successors get the same set of operand types.
230       return TypeRange(regionReturnOperands->getTypes());
231     };
232 
233     if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
234       return failure();
235   }
236 
237   return success();
238 }
239 
240 /// Return `true` if `a` and `b` are in mutually exclusive regions.
241 ///
242 /// 1. Find the first common of `a` and `b` (ancestor) that implements
243 ///    RegionBranchOpInterface.
244 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
245 ///    contained.
246 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
247 ///    mutually exclusive if they are not reachable from each other as per
248 ///    RegionBranchOpInterface::getSuccessorRegions.
249 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
250   assert(a && "expected non-empty operation");
251   assert(b && "expected non-empty operation");
252 
253   auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
254   while (branchOp) {
255     // Check if b is inside branchOp. (We already know that a is.)
256     if (!branchOp->isProperAncestor(b)) {
257       // Check next enclosing RegionBranchOpInterface.
258       branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
259       continue;
260     }
261 
262     // b is contained in branchOp. Retrieve the regions in which `a` and `b`
263     // are contained.
264     Region *regionA = nullptr, *regionB = nullptr;
265     for (Region &r : branchOp->getRegions()) {
266       if (r.findAncestorOpInRegion(*a)) {
267         assert(!regionA && "already found a region for a");
268         regionA = &r;
269       }
270       if (r.findAncestorOpInRegion(*b)) {
271         assert(!regionB && "already found a region for b");
272         regionB = &r;
273       }
274     }
275     assert(regionA && regionB && "could not find region of op");
276 
277     // Helper function that checks if region `r` is reachable from region
278     // `begin`.
279     std::function<bool(Region *, Region *)> isRegionReachable =
280         [&](Region *begin, Region *r) {
281           if (begin == r)
282             return true;
283           if (begin == nullptr)
284             return false;
285           // Compute index of region.
286           int64_t beginIndex = -1;
287           for (const auto &it : llvm::enumerate(branchOp->getRegions()))
288             if (&it.value() == begin)
289               beginIndex = it.index();
290           assert(beginIndex != -1 && "could not find region in op");
291           // Retrieve all successors of the region.
292           SmallVector<RegionSuccessor> successors;
293           branchOp.getSuccessorRegions(beginIndex, successors);
294           // Call function recursively on all successors.
295           for (RegionSuccessor successor : successors)
296             if (isRegionReachable(successor.getSuccessor(), r))
297               return true;
298           return false;
299         };
300 
301     // `a` and `b` are in mutually exclusive regions if neither region is
302     // reachable from the other region.
303     return !isRegionReachable(regionA, regionB) &&
304            !isRegionReachable(regionB, regionA);
305   }
306 
307   // Could not find a common RegionBranchOpInterface among a's and b's
308   // ancestors.
309   return false;
310 }
311 
312 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
313   SmallVector<bool> visited(getOperation()->getNumRegions(), false);
314   visited[index] = true;
315 
316   // Retrieve all successors of the region and enqueue them in the worklist.
317   SmallVector<unsigned> worklist;
318   auto enqueueAllSuccessors = [&](unsigned index) {
319     SmallVector<RegionSuccessor> successors;
320     this->getSuccessorRegions(index, successors);
321     for (RegionSuccessor successor : successors)
322       if (!successor.isParent())
323         worklist.push_back(successor.getSuccessor()->getRegionNumber());
324   };
325   enqueueAllSuccessors(index);
326 
327   // Process all regions in the worklist via DFS.
328   while (!worklist.empty()) {
329     unsigned nextRegion = worklist.pop_back_val();
330     if (nextRegion == index)
331       return true;
332     if (visited[nextRegion])
333       continue;
334     visited[nextRegion] = true;
335     enqueueAllSuccessors(nextRegion);
336   }
337 
338   return false;
339 }
340 
341 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
342   while (Region *region = op->getParentRegion()) {
343     op = region->getParentOp();
344     if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
345       if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
346         return region;
347   }
348   return nullptr;
349 }
350 
351 Region *mlir::getEnclosingRepetitiveRegion(Value value) {
352   Region *region = value.getParentRegion();
353   while (region) {
354     Operation *op = region->getParentOp();
355     if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
356       if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
357         return region;
358     region = op->getParentRegion();
359   }
360   return nullptr;
361 }
362 
363 //===----------------------------------------------------------------------===//
364 // RegionBranchTerminatorOpInterface
365 //===----------------------------------------------------------------------===//
366 
367 /// Returns true if the given operation is either annotated with the
368 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
369 bool mlir::isRegionReturnLike(Operation *operation) {
370   return dyn_cast<RegionBranchTerminatorOpInterface>(operation) ||
371          operation->hasTrait<OpTrait::ReturnLike>();
372 }
373 
374 /// Returns the mutable operands that are passed to the region with the given
375 /// `regionIndex`. If the operation does not implement the
376 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
377 /// result will be `llvm::None`. In all other cases, the resulting
378 /// `OperandRange` represents all operands that are passed to the specified
379 /// successor region. If `regionIndex` is `llvm::None`, all operands that are
380 /// passed to the parent operation will be returned.
381 Optional<MutableOperandRange>
382 mlir::getMutableRegionBranchSuccessorOperands(Operation *operation,
383                                               Optional<unsigned> regionIndex) {
384   // Try to query a RegionBranchTerminatorOpInterface to determine
385   // all successor operands that will be passed to the successor
386   // input arguments.
387   if (auto regionTerminatorInterface =
388           dyn_cast<RegionBranchTerminatorOpInterface>(operation))
389     return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex);
390 
391   // TODO: The ReturnLike trait should imply a default implementation of the
392   // RegionBranchTerminatorOpInterface. This would make this code significantly
393   // easier. Furthermore, this may even make this function obsolete.
394   if (operation->hasTrait<OpTrait::ReturnLike>())
395     return MutableOperandRange(operation);
396   return llvm::None;
397 }
398 
399 /// Returns the read only operands that are passed to the region with the given
400 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
401 /// information.
402 Optional<OperandRange>
403 mlir::getRegionBranchSuccessorOperands(Operation *operation,
404                                        Optional<unsigned> regionIndex) {
405   auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex);
406   return range ? Optional<OperandRange>(*range) : llvm::None;
407 }
408