1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Interfaces/ControlFlowInterfaces.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // ControlFlowInterfaces
19 //===----------------------------------------------------------------------===//
20 
21 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
22 
23 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
24     : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
25 }
26 
27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
28                                      MutableOperandRange forwardedOperands)
29     : producedOperandCount(producedOperandCount),
30       forwardedOperands(std::move(forwardedOperands)) {}
31 
32 //===----------------------------------------------------------------------===//
33 // BranchOpInterface
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
37 /// successor if 'operandIndex' is within the range of 'operands', or None if
38 /// `operandIndex` isn't a successor operand index.
39 Optional<BlockArgument>
40 detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
41                                    unsigned operandIndex, Block *successor) {
42   OperandRange forwardedOperands = operands.getForwardedOperands();
43   // Check that the operands are valid.
44   if (forwardedOperands.empty())
45     return llvm::None;
46 
47   // Check to ensure that this operand is within the range.
48   unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
49   if (operandIndex < operandsStart ||
50       operandIndex >= (operandsStart + forwardedOperands.size()))
51     return llvm::None;
52 
53   // Index the successor.
54   unsigned argIndex =
55       operands.getProducedOperandCount() + operandIndex - operandsStart;
56   return successor->getArgument(argIndex);
57 }
58 
59 /// Verify that the given operands match those of the given successor block.
60 LogicalResult
61 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
62                                       const SuccessorOperands &operands) {
63   // Check the count.
64   unsigned operandCount = operands.size();
65   Block *destBB = op->getSuccessor(succNo);
66   if (operandCount != destBB->getNumArguments())
67     return op->emitError() << "branch has " << operandCount
68                            << " operands for successor #" << succNo
69                            << ", but target block has "
70                            << destBB->getNumArguments();
71 
72   // Check the types.
73   for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
74        ++i) {
75     if (!cast<BranchOpInterface>(op).areTypesCompatible(
76             operands[i].getType(), destBB->getArgument(i).getType()))
77       return op->emitError() << "type mismatch for bb argument #" << i
78                              << " of successor #" << succNo;
79   }
80   return success();
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // RegionBranchOpInterface
85 //===----------------------------------------------------------------------===//
86 
87 /// Verify that types match along all region control flow edges originating from
88 /// `sourceNo` (region # if source is a region, llvm::None if source is parent
89 /// op). `getInputsTypesForRegion` is a function that returns the types of the
90 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if
91 /// the exact type match verification is not necessary (e.g., if the Op verifies
92 /// the match itself).
93 static LogicalResult
94 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
95                          function_ref<Optional<TypeRange>(Optional<unsigned>)>
96                              getInputsTypesForRegion) {
97   auto regionInterface = cast<RegionBranchOpInterface>(op);
98 
99   SmallVector<RegionSuccessor, 2> successors;
100   unsigned numInputs;
101   if (sourceNo) {
102     Region &srcRegion = op->getRegion(sourceNo.getValue());
103     numInputs = srcRegion.getNumArguments();
104   } else {
105     numInputs = op->getNumOperands();
106   }
107   SmallVector<Attribute, 2> operands(numInputs, nullptr);
108   regionInterface.getSuccessorRegions(sourceNo, operands, successors);
109 
110   for (RegionSuccessor &succ : successors) {
111     Optional<unsigned> succRegionNo;
112     if (!succ.isParent())
113       succRegionNo = succ.getSuccessor()->getRegionNumber();
114 
115     auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
116       diag << "from ";
117       if (sourceNo)
118         diag << "Region #" << sourceNo.getValue();
119       else
120         diag << "parent operands";
121 
122       diag << " to ";
123       if (succRegionNo)
124         diag << "Region #" << succRegionNo.getValue();
125       else
126         diag << "parent results";
127       return diag;
128     };
129 
130     Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
131     if (!sourceTypes.hasValue())
132       continue;
133 
134     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
135     if (sourceTypes->size() != succInputsTypes.size()) {
136       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
137       return printEdgeName(diag) << ": source has " << sourceTypes->size()
138                                  << " operands, but target successor needs "
139                                  << succInputsTypes.size();
140     }
141 
142     for (const auto &typesIdx :
143          llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
144       Type sourceType = std::get<0>(typesIdx.value());
145       Type inputType = std::get<1>(typesIdx.value());
146       if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
147         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
148         return printEdgeName(diag)
149                << ": source type #" << typesIdx.index() << " " << sourceType
150                << " should match input type #" << typesIdx.index() << " "
151                << inputType;
152       }
153     }
154   }
155   return success();
156 }
157 
158 /// Verify that types match along control flow edges described the given op.
159 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
160   auto regionInterface = cast<RegionBranchOpInterface>(op);
161 
162   auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
163     if (regionNo.hasValue()) {
164       return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
165           .getTypes();
166     }
167 
168     // If the successor of a parent op is the parent itself
169     // RegionBranchOpInterface does not have an API to query what the entry
170     // operands will be in that case. Vend out the result types of the op in
171     // that case so that type checking succeeds for this case.
172     return op->getResultTypes();
173   };
174 
175   // Verify types along control flow edges originating from the parent.
176   if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
177     return failure();
178 
179   // RegionBranchOpInterface should not be implemented by Ops that do not have
180   // attached regions.
181   assert(op->getNumRegions() != 0);
182 
183   auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
184     if (lhs.size() != rhs.size())
185       return false;
186     for (auto types : llvm::zip(lhs, rhs)) {
187       if (!regionInterface.areTypesCompatible(std::get<0>(types),
188                                               std::get<1>(types))) {
189         return false;
190       }
191     }
192     return true;
193   };
194 
195   // Verify types along control flow edges originating from each region.
196   for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
197     Region &region = op->getRegion(regionNo);
198 
199     // Since there can be multiple `ReturnLike` terminators or others
200     // implementing the `RegionBranchTerminatorOpInterface`, all should have the
201     // same operand types when passing them to the same region.
202 
203     Optional<OperandRange> regionReturnOperands;
204     for (Block &block : region) {
205       Operation *terminator = block.getTerminator();
206       auto terminatorOperands =
207           getRegionBranchSuccessorOperands(terminator, regionNo);
208       if (!terminatorOperands)
209         continue;
210 
211       if (!regionReturnOperands) {
212         regionReturnOperands = terminatorOperands;
213         continue;
214       }
215 
216       // Found more than one ReturnLike terminator. Make sure the operand types
217       // match with the first one.
218       if (!areTypesCompatible(regionReturnOperands->getTypes(),
219                               terminatorOperands->getTypes()))
220         return op->emitOpError("Region #")
221                << regionNo
222                << " operands mismatch between return-like terminators";
223     }
224 
225     auto inputTypesFromRegion =
226         [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
227       // If there is no return-like terminator, the op itself should verify
228       // type consistency.
229       if (!regionReturnOperands)
230         return llvm::None;
231 
232       // All successors get the same set of operand types.
233       return TypeRange(regionReturnOperands->getTypes());
234     };
235 
236     if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
237       return failure();
238   }
239 
240   return success();
241 }
242 
243 /// Return `true` if region `r` is reachable from region `begin` according to
244 /// the RegionBranchOpInterface (by taking a branch).
245 static bool isRegionReachable(Region *begin, Region *r) {
246   assert(begin->getParentOp() == r->getParentOp() &&
247          "expected that both regions belong to the same op");
248   auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
249   SmallVector<bool> visited(op->getNumRegions(), false);
250   visited[begin->getRegionNumber()] = true;
251 
252   // Retrieve all successors of the region and enqueue them in the worklist.
253   SmallVector<unsigned> worklist;
254   auto enqueueAllSuccessors = [&](unsigned index) {
255     SmallVector<RegionSuccessor> successors;
256     op.getSuccessorRegions(index, successors);
257     for (RegionSuccessor successor : successors)
258       if (!successor.isParent())
259         worklist.push_back(successor.getSuccessor()->getRegionNumber());
260   };
261   enqueueAllSuccessors(begin->getRegionNumber());
262 
263   // Process all regions in the worklist via DFS.
264   while (!worklist.empty()) {
265     unsigned nextRegion = worklist.pop_back_val();
266     if (nextRegion == r->getRegionNumber())
267       return true;
268     if (visited[nextRegion])
269       continue;
270     visited[nextRegion] = true;
271     enqueueAllSuccessors(nextRegion);
272   }
273 
274   return false;
275 }
276 
277 /// Return `true` if `a` and `b` are in mutually exclusive regions.
278 ///
279 /// 1. Find the first common of `a` and `b` (ancestor) that implements
280 ///    RegionBranchOpInterface.
281 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
282 ///    contained.
283 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
284 ///    mutually exclusive if they are not reachable from each other as per
285 ///    RegionBranchOpInterface::getSuccessorRegions.
286 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
287   assert(a && "expected non-empty operation");
288   assert(b && "expected non-empty operation");
289 
290   auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
291   while (branchOp) {
292     // Check if b is inside branchOp. (We already know that a is.)
293     if (!branchOp->isProperAncestor(b)) {
294       // Check next enclosing RegionBranchOpInterface.
295       branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
296       continue;
297     }
298 
299     // b is contained in branchOp. Retrieve the regions in which `a` and `b`
300     // are contained.
301     Region *regionA = nullptr, *regionB = nullptr;
302     for (Region &r : branchOp->getRegions()) {
303       if (r.findAncestorOpInRegion(*a)) {
304         assert(!regionA && "already found a region for a");
305         regionA = &r;
306       }
307       if (r.findAncestorOpInRegion(*b)) {
308         assert(!regionB && "already found a region for b");
309         regionB = &r;
310       }
311     }
312     assert(regionA && regionB && "could not find region of op");
313 
314     // `a` and `b` are in mutually exclusive regions if both regions are
315     // distinct and neither region is reachable from the other region.
316     return regionA != regionB && !isRegionReachable(regionA, regionB) &&
317            !isRegionReachable(regionB, regionA);
318   }
319 
320   // Could not find a common RegionBranchOpInterface among a's and b's
321   // ancestors.
322   return false;
323 }
324 
325 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
326   Region *region = &getOperation()->getRegion(index);
327   return isRegionReachable(region, region);
328 }
329 
330 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
331   while (Region *region = op->getParentRegion()) {
332     op = region->getParentOp();
333     if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
334       if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
335         return region;
336   }
337   return nullptr;
338 }
339 
340 Region *mlir::getEnclosingRepetitiveRegion(Value value) {
341   Region *region = value.getParentRegion();
342   while (region) {
343     Operation *op = region->getParentOp();
344     if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
345       if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
346         return region;
347     region = op->getParentRegion();
348   }
349   return nullptr;
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // RegionBranchTerminatorOpInterface
354 //===----------------------------------------------------------------------===//
355 
356 /// Returns true if the given operation is either annotated with the
357 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
358 bool mlir::isRegionReturnLike(Operation *operation) {
359   return dyn_cast<RegionBranchTerminatorOpInterface>(operation) ||
360          operation->hasTrait<OpTrait::ReturnLike>();
361 }
362 
363 /// Returns the mutable operands that are passed to the region with the given
364 /// `regionIndex`. If the operation does not implement the
365 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
366 /// result will be `llvm::None`. In all other cases, the resulting
367 /// `OperandRange` represents all operands that are passed to the specified
368 /// successor region. If `regionIndex` is `llvm::None`, all operands that are
369 /// passed to the parent operation will be returned.
370 Optional<MutableOperandRange>
371 mlir::getMutableRegionBranchSuccessorOperands(Operation *operation,
372                                               Optional<unsigned> regionIndex) {
373   // Try to query a RegionBranchTerminatorOpInterface to determine
374   // all successor operands that will be passed to the successor
375   // input arguments.
376   if (auto regionTerminatorInterface =
377           dyn_cast<RegionBranchTerminatorOpInterface>(operation))
378     return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex);
379 
380   // TODO: The ReturnLike trait should imply a default implementation of the
381   // RegionBranchTerminatorOpInterface. This would make this code significantly
382   // easier. Furthermore, this may even make this function obsolete.
383   if (operation->hasTrait<OpTrait::ReturnLike>())
384     return MutableOperandRange(operation);
385   return llvm::None;
386 }
387 
388 /// Returns the read only operands that are passed to the region with the given
389 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
390 /// information.
391 Optional<OperandRange>
392 mlir::getRegionBranchSuccessorOperands(Operation *operation,
393                                        Optional<unsigned> regionIndex) {
394   auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex);
395   return range ? Optional<OperandRange>(*range) : llvm::None;
396 }
397