17ce1e7abSRiver Riddle //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// 27ce1e7abSRiver Riddle // 37ce1e7abSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 47ce1e7abSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 57ce1e7abSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 67ce1e7abSRiver Riddle // 77ce1e7abSRiver Riddle //===----------------------------------------------------------------------===// 87ce1e7abSRiver Riddle 97ce1e7abSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.h" 1009f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 11a3ad8f92SRahul Joshi #include "llvm/ADT/SmallPtrSet.h" 127ce1e7abSRiver Riddle 137ce1e7abSRiver Riddle using namespace mlir; 147ce1e7abSRiver Riddle 157ce1e7abSRiver Riddle //===----------------------------------------------------------------------===// 167ce1e7abSRiver Riddle // ControlFlowInterfaces 177ce1e7abSRiver Riddle //===----------------------------------------------------------------------===// 187ce1e7abSRiver Riddle 197ce1e7abSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" 207ce1e7abSRiver Riddle 217ce1e7abSRiver Riddle //===----------------------------------------------------------------------===// 227ce1e7abSRiver Riddle // BranchOpInterface 237ce1e7abSRiver Riddle //===----------------------------------------------------------------------===// 247ce1e7abSRiver Riddle 257ce1e7abSRiver Riddle /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some 267ce1e7abSRiver Riddle /// successor if 'operandIndex' is within the range of 'operands', or None if 277ce1e7abSRiver Riddle /// `operandIndex` isn't a successor operand index. 28a3ad8f92SRahul Joshi Optional<BlockArgument> 29a3ad8f92SRahul Joshi detail::getBranchSuccessorArgument(Optional<OperandRange> operands, 30a3ad8f92SRahul Joshi unsigned operandIndex, Block *successor) { 317ce1e7abSRiver Riddle // Check that the operands are valid. 327ce1e7abSRiver Riddle if (!operands || operands->empty()) 337ce1e7abSRiver Riddle return llvm::None; 347ce1e7abSRiver Riddle 357ce1e7abSRiver Riddle // Check to ensure that this operand is within the range. 367ce1e7abSRiver Riddle unsigned operandsStart = operands->getBeginOperandIndex(); 377ce1e7abSRiver Riddle if (operandIndex < operandsStart || 387ce1e7abSRiver Riddle operandIndex >= (operandsStart + operands->size())) 397ce1e7abSRiver Riddle return llvm::None; 407ce1e7abSRiver Riddle 417ce1e7abSRiver Riddle // Index the successor. 427ce1e7abSRiver Riddle unsigned argIndex = operandIndex - operandsStart; 437ce1e7abSRiver Riddle return successor->getArgument(argIndex); 447ce1e7abSRiver Riddle } 457ce1e7abSRiver Riddle 467ce1e7abSRiver Riddle /// Verify that the given operands match those of the given successor block. 477ce1e7abSRiver Riddle LogicalResult 48a3ad8f92SRahul Joshi detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, 497ce1e7abSRiver Riddle Optional<OperandRange> operands) { 507ce1e7abSRiver Riddle if (!operands) 517ce1e7abSRiver Riddle return success(); 527ce1e7abSRiver Riddle 537ce1e7abSRiver Riddle // Check the count. 547ce1e7abSRiver Riddle unsigned operandCount = operands->size(); 557ce1e7abSRiver Riddle Block *destBB = op->getSuccessor(succNo); 567ce1e7abSRiver Riddle if (operandCount != destBB->getNumArguments()) 577ce1e7abSRiver Riddle return op->emitError() << "branch has " << operandCount 587ce1e7abSRiver Riddle << " operands for successor #" << succNo 597ce1e7abSRiver Riddle << ", but target block has " 607ce1e7abSRiver Riddle << destBB->getNumArguments(); 617ce1e7abSRiver Riddle 627ce1e7abSRiver Riddle // Check the types. 637ce1e7abSRiver Riddle auto operandIt = operands->begin(); 647ce1e7abSRiver Riddle for (unsigned i = 0; i != operandCount; ++i, ++operandIt) { 657ce1e7abSRiver Riddle if ((*operandIt).getType() != destBB->getArgument(i).getType()) 667ce1e7abSRiver Riddle return op->emitError() << "type mismatch for bb argument #" << i 677ce1e7abSRiver Riddle << " of successor #" << succNo; 687ce1e7abSRiver Riddle } 697ce1e7abSRiver Riddle return success(); 707ce1e7abSRiver Riddle } 71a3ad8f92SRahul Joshi 72a3ad8f92SRahul Joshi //===----------------------------------------------------------------------===// 73a3ad8f92SRahul Joshi // RegionBranchOpInterface 74a3ad8f92SRahul Joshi //===----------------------------------------------------------------------===// 75a3ad8f92SRahul Joshi 76bb0d5f76SEugene Zhulenev // A constant value to represent unknown number of region invocations. 77bb0d5f76SEugene Zhulenev const int64_t mlir::kUnknownNumRegionInvocations = -1; 78bb0d5f76SEugene Zhulenev 79a3ad8f92SRahul Joshi /// Verify that types match along all region control flow edges originating from 80a3ad8f92SRahul Joshi /// `sourceNo` (region # if source is a region, llvm::None if source is parent 81a3ad8f92SRahul Joshi /// op). `getInputsTypesForRegion` is a function that returns the types of the 8279716559SAlex Zinenko /// inputs that flow from `sourceIndex' to the given region, or llvm::None if 8379716559SAlex Zinenko /// the exact type match verification is not necessary (e.g., if the Op verifies 8479716559SAlex Zinenko /// the match itself). 8579716559SAlex Zinenko static LogicalResult 8679716559SAlex Zinenko verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo, 8779716559SAlex Zinenko function_ref<Optional<TypeRange>(Optional<unsigned>)> 8879716559SAlex Zinenko getInputsTypesForRegion) { 89a3ad8f92SRahul Joshi auto regionInterface = cast<RegionBranchOpInterface>(op); 90a3ad8f92SRahul Joshi 91a3ad8f92SRahul Joshi SmallVector<RegionSuccessor, 2> successors; 92a3ad8f92SRahul Joshi unsigned numInputs; 93a3ad8f92SRahul Joshi if (sourceNo) { 94a3ad8f92SRahul Joshi Region &srcRegion = op->getRegion(sourceNo.getValue()); 95a3ad8f92SRahul Joshi numInputs = srcRegion.getNumArguments(); 96a3ad8f92SRahul Joshi } else { 97a3ad8f92SRahul Joshi numInputs = op->getNumOperands(); 98a3ad8f92SRahul Joshi } 99a3ad8f92SRahul Joshi SmallVector<Attribute, 2> operands(numInputs, nullptr); 100a3ad8f92SRahul Joshi regionInterface.getSuccessorRegions(sourceNo, operands, successors); 101a3ad8f92SRahul Joshi 102a3ad8f92SRahul Joshi for (RegionSuccessor &succ : successors) { 103a3ad8f92SRahul Joshi Optional<unsigned> succRegionNo; 104a3ad8f92SRahul Joshi if (!succ.isParent()) 105a3ad8f92SRahul Joshi succRegionNo = succ.getSuccessor()->getRegionNumber(); 106a3ad8f92SRahul Joshi 107a3ad8f92SRahul Joshi auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { 108a3ad8f92SRahul Joshi diag << "from "; 109a3ad8f92SRahul Joshi if (sourceNo) 110a3ad8f92SRahul Joshi diag << "Region #" << sourceNo.getValue(); 111a3ad8f92SRahul Joshi else 112be35264aSSean Silva diag << "parent operands"; 113a3ad8f92SRahul Joshi 114a3ad8f92SRahul Joshi diag << " to "; 115a3ad8f92SRahul Joshi if (succRegionNo) 116a3ad8f92SRahul Joshi diag << "Region #" << succRegionNo.getValue(); 117a3ad8f92SRahul Joshi else 118be35264aSSean Silva diag << "parent results"; 119a3ad8f92SRahul Joshi return diag; 120a3ad8f92SRahul Joshi }; 121a3ad8f92SRahul Joshi 12279716559SAlex Zinenko Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo); 12379716559SAlex Zinenko if (!sourceTypes.hasValue()) 12479716559SAlex Zinenko continue; 12579716559SAlex Zinenko 126a3ad8f92SRahul Joshi TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); 12779716559SAlex Zinenko if (sourceTypes->size() != succInputsTypes.size()) { 128a3ad8f92SRahul Joshi InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); 12979716559SAlex Zinenko return printEdgeName(diag) << ": source has " << sourceTypes->size() 130be35264aSSean Silva << " operands, but target successor needs " 131a3ad8f92SRahul Joshi << succInputsTypes.size(); 132a3ad8f92SRahul Joshi } 133a3ad8f92SRahul Joshi 134*e4853be2SMehdi Amini for (const auto &typesIdx : 13579716559SAlex Zinenko llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { 136a3ad8f92SRahul Joshi Type sourceType = std::get<0>(typesIdx.value()); 137a3ad8f92SRahul Joshi Type inputType = std::get<1>(typesIdx.value()); 138a3ad8f92SRahul Joshi if (sourceType != inputType) { 139a3ad8f92SRahul Joshi InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); 140a3ad8f92SRahul Joshi return printEdgeName(diag) 141be35264aSSean Silva << ": source type #" << typesIdx.index() << " " << sourceType 142be35264aSSean Silva << " should match input type #" << typesIdx.index() << " " 143a3ad8f92SRahul Joshi << inputType; 144a3ad8f92SRahul Joshi } 145a3ad8f92SRahul Joshi } 146a3ad8f92SRahul Joshi } 147a3ad8f92SRahul Joshi return success(); 148a3ad8f92SRahul Joshi } 149a3ad8f92SRahul Joshi 150a3ad8f92SRahul Joshi /// Verify that types match along control flow edges described the given op. 151a3ad8f92SRahul Joshi LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { 152a3ad8f92SRahul Joshi auto regionInterface = cast<RegionBranchOpInterface>(op); 153a3ad8f92SRahul Joshi 154a3ad8f92SRahul Joshi auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange { 155a3ad8f92SRahul Joshi if (regionNo.hasValue()) { 156a3ad8f92SRahul Joshi return regionInterface.getSuccessorEntryOperands(regionNo.getValue()) 157a3ad8f92SRahul Joshi .getTypes(); 158a3ad8f92SRahul Joshi } 159a3ad8f92SRahul Joshi 160a3ad8f92SRahul Joshi // If the successor of a parent op is the parent itself 161a3ad8f92SRahul Joshi // RegionBranchOpInterface does not have an API to query what the entry 162a3ad8f92SRahul Joshi // operands will be in that case. Vend out the result types of the op in 163a3ad8f92SRahul Joshi // that case so that type checking succeeds for this case. 164a3ad8f92SRahul Joshi return op->getResultTypes(); 165a3ad8f92SRahul Joshi }; 166a3ad8f92SRahul Joshi 167a3ad8f92SRahul Joshi // Verify types along control flow edges originating from the parent. 168a3ad8f92SRahul Joshi if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent))) 169a3ad8f92SRahul Joshi return failure(); 170a3ad8f92SRahul Joshi 171a3ad8f92SRahul Joshi // RegionBranchOpInterface should not be implemented by Ops that do not have 172a3ad8f92SRahul Joshi // attached regions. 173a3ad8f92SRahul Joshi assert(op->getNumRegions() != 0); 174a3ad8f92SRahul Joshi 175a3ad8f92SRahul Joshi // Verify types along control flow edges originating from each region. 176a3ad8f92SRahul Joshi for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { 177a3ad8f92SRahul Joshi Region ®ion = op->getRegion(regionNo); 178a3ad8f92SRahul Joshi 17904253320SMarcel Koester // Since there can be multiple `ReturnLike` terminators or others 18004253320SMarcel Koester // implementing the `RegionBranchTerminatorOpInterface`, all should have the 18104253320SMarcel Koester // same operand types when passing them to the same region. 182a3ad8f92SRahul Joshi 18304253320SMarcel Koester Optional<OperandRange> regionReturnOperands; 184a3ad8f92SRahul Joshi for (Block &block : region) { 185a3ad8f92SRahul Joshi Operation *terminator = block.getTerminator(); 18604253320SMarcel Koester auto terminatorOperands = 18704253320SMarcel Koester getRegionBranchSuccessorOperands(terminator, regionNo); 18804253320SMarcel Koester if (!terminatorOperands) 189a3ad8f92SRahul Joshi continue; 190a3ad8f92SRahul Joshi 19104253320SMarcel Koester if (!regionReturnOperands) { 19204253320SMarcel Koester regionReturnOperands = terminatorOperands; 193a3ad8f92SRahul Joshi continue; 194a3ad8f92SRahul Joshi } 195a3ad8f92SRahul Joshi 196a3ad8f92SRahul Joshi // Found more than one ReturnLike terminator. Make sure the operand types 197a3ad8f92SRahul Joshi // match with the first one. 19804253320SMarcel Koester if (regionReturnOperands->getTypes() != terminatorOperands->getTypes()) 199a3ad8f92SRahul Joshi return op->emitOpError("Region #") 200a3ad8f92SRahul Joshi << regionNo 201a3ad8f92SRahul Joshi << " operands mismatch between return-like terminators"; 202a3ad8f92SRahul Joshi } 203a3ad8f92SRahul Joshi 20479716559SAlex Zinenko auto inputTypesFromRegion = 20579716559SAlex Zinenko [&](Optional<unsigned> regionNo) -> Optional<TypeRange> { 20679716559SAlex Zinenko // If there is no return-like terminator, the op itself should verify 20779716559SAlex Zinenko // type consistency. 20804253320SMarcel Koester if (!regionReturnOperands) 20979716559SAlex Zinenko return llvm::None; 21079716559SAlex Zinenko 21104253320SMarcel Koester // All successors get the same set of operand types. 21204253320SMarcel Koester return TypeRange(regionReturnOperands->getTypes()); 213a3ad8f92SRahul Joshi }; 214a3ad8f92SRahul Joshi 215a3ad8f92SRahul Joshi if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) 216a3ad8f92SRahul Joshi return failure(); 217a3ad8f92SRahul Joshi } 218a3ad8f92SRahul Joshi 219a3ad8f92SRahul Joshi return success(); 220a3ad8f92SRahul Joshi } 22104253320SMarcel Koester 222a5c2f782SMatthias Springer /// Return `true` if `a` and `b` are in mutually exclusive regions. 223a5c2f782SMatthias Springer /// 224a5c2f782SMatthias Springer /// 1. Find the first common of `a` and `b` (ancestor) that implements 225a5c2f782SMatthias Springer /// RegionBranchOpInterface. 226a5c2f782SMatthias Springer /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are 227a5c2f782SMatthias Springer /// contained. 228a5c2f782SMatthias Springer /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are 229a5c2f782SMatthias Springer /// mutually exclusive if they are not reachable from each other as per 230a5c2f782SMatthias Springer /// RegionBranchOpInterface::getSuccessorRegions. 231a5c2f782SMatthias Springer bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { 232a5c2f782SMatthias Springer assert(a && "expected non-empty operation"); 233a5c2f782SMatthias Springer assert(b && "expected non-empty operation"); 234a5c2f782SMatthias Springer 235a5c2f782SMatthias Springer auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); 236a5c2f782SMatthias Springer while (branchOp) { 237a5c2f782SMatthias Springer // Check if b is inside branchOp. (We already know that a is.) 238a5c2f782SMatthias Springer if (!branchOp->isProperAncestor(b)) { 239a5c2f782SMatthias Springer // Check next enclosing RegionBranchOpInterface. 240a5c2f782SMatthias Springer branchOp = branchOp->getParentOfType<RegionBranchOpInterface>(); 241a5c2f782SMatthias Springer continue; 242a5c2f782SMatthias Springer } 243a5c2f782SMatthias Springer 244a5c2f782SMatthias Springer // b is contained in branchOp. Retrieve the regions in which `a` and `b` 245a5c2f782SMatthias Springer // are contained. 246a5c2f782SMatthias Springer Region *regionA = nullptr, *regionB = nullptr; 247a5c2f782SMatthias Springer for (Region &r : branchOp->getRegions()) { 248a5c2f782SMatthias Springer if (r.findAncestorOpInRegion(*a)) { 249a5c2f782SMatthias Springer assert(!regionA && "already found a region for a"); 250a5c2f782SMatthias Springer regionA = &r; 251a5c2f782SMatthias Springer } 252a5c2f782SMatthias Springer if (r.findAncestorOpInRegion(*b)) { 253a5c2f782SMatthias Springer assert(!regionB && "already found a region for b"); 254a5c2f782SMatthias Springer regionB = &r; 255a5c2f782SMatthias Springer } 256a5c2f782SMatthias Springer } 257a5c2f782SMatthias Springer assert(regionA && regionB && "could not find region of op"); 258a5c2f782SMatthias Springer 259a5c2f782SMatthias Springer // Helper function that checks if region `r` is reachable from region 260a5c2f782SMatthias Springer // `begin`. 261a5c2f782SMatthias Springer std::function<bool(Region *, Region *)> isRegionReachable = 262a5c2f782SMatthias Springer [&](Region *begin, Region *r) { 263a5c2f782SMatthias Springer if (begin == r) 264a5c2f782SMatthias Springer return true; 265a5c2f782SMatthias Springer if (begin == nullptr) 266a5c2f782SMatthias Springer return false; 267a5c2f782SMatthias Springer // Compute index of region. 268a5c2f782SMatthias Springer int64_t beginIndex = -1; 269*e4853be2SMehdi Amini for (const auto &it : llvm::enumerate(branchOp->getRegions())) 270a5c2f782SMatthias Springer if (&it.value() == begin) 271a5c2f782SMatthias Springer beginIndex = it.index(); 272a5c2f782SMatthias Springer assert(beginIndex != -1 && "could not find region in op"); 273a5c2f782SMatthias Springer // Retrieve all successors of the region. 274a5c2f782SMatthias Springer SmallVector<RegionSuccessor> successors; 275a5c2f782SMatthias Springer branchOp.getSuccessorRegions(beginIndex, successors); 276a5c2f782SMatthias Springer // Call function recursively on all successors. 277a5c2f782SMatthias Springer for (RegionSuccessor successor : successors) 278a5c2f782SMatthias Springer if (isRegionReachable(successor.getSuccessor(), r)) 279a5c2f782SMatthias Springer return true; 280a5c2f782SMatthias Springer return false; 281a5c2f782SMatthias Springer }; 282a5c2f782SMatthias Springer 283a5c2f782SMatthias Springer // `a` and `b` are in mutually exclusive regions if neither region is 284a5c2f782SMatthias Springer // reachable from the other region. 285a5c2f782SMatthias Springer return !isRegionReachable(regionA, regionB) && 286a5c2f782SMatthias Springer !isRegionReachable(regionB, regionA); 287a5c2f782SMatthias Springer } 288a5c2f782SMatthias Springer 289a5c2f782SMatthias Springer // Could not find a common RegionBranchOpInterface among a's and b's 290a5c2f782SMatthias Springer // ancestors. 291a5c2f782SMatthias Springer return false; 292a5c2f782SMatthias Springer } 293a5c2f782SMatthias Springer 29404253320SMarcel Koester //===----------------------------------------------------------------------===// 29504253320SMarcel Koester // RegionBranchTerminatorOpInterface 29604253320SMarcel Koester //===----------------------------------------------------------------------===// 29704253320SMarcel Koester 29804253320SMarcel Koester /// Returns true if the given operation is either annotated with the 29904253320SMarcel Koester /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`. 30004253320SMarcel Koester bool mlir::isRegionReturnLike(Operation *operation) { 30104253320SMarcel Koester return dyn_cast<RegionBranchTerminatorOpInterface>(operation) || 30204253320SMarcel Koester operation->hasTrait<OpTrait::ReturnLike>(); 30304253320SMarcel Koester } 30404253320SMarcel Koester 30504253320SMarcel Koester /// Returns the mutable operands that are passed to the region with the given 30604253320SMarcel Koester /// `regionIndex`. If the operation does not implement the 30704253320SMarcel Koester /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the 30804253320SMarcel Koester /// result will be `llvm::None`. In all other cases, the resulting 30904253320SMarcel Koester /// `OperandRange` represents all operands that are passed to the specified 31004253320SMarcel Koester /// successor region. If `regionIndex` is `llvm::None`, all operands that are 31104253320SMarcel Koester /// passed to the parent operation will be returned. 31204253320SMarcel Koester Optional<MutableOperandRange> 31304253320SMarcel Koester mlir::getMutableRegionBranchSuccessorOperands(Operation *operation, 31404253320SMarcel Koester Optional<unsigned> regionIndex) { 31504253320SMarcel Koester // Try to query a RegionBranchTerminatorOpInterface to determine 31604253320SMarcel Koester // all successor operands that will be passed to the successor 31704253320SMarcel Koester // input arguments. 31804253320SMarcel Koester if (auto regionTerminatorInterface = 31904253320SMarcel Koester dyn_cast<RegionBranchTerminatorOpInterface>(operation)) 32004253320SMarcel Koester return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex); 32104253320SMarcel Koester 32204253320SMarcel Koester // TODO: The ReturnLike trait should imply a default implementation of the 32304253320SMarcel Koester // RegionBranchTerminatorOpInterface. This would make this code significantly 32404253320SMarcel Koester // easier. Furthermore, this may even make this function obsolete. 32504253320SMarcel Koester if (operation->hasTrait<OpTrait::ReturnLike>()) 32604253320SMarcel Koester return MutableOperandRange(operation); 32704253320SMarcel Koester return llvm::None; 32804253320SMarcel Koester } 32904253320SMarcel Koester 33004253320SMarcel Koester /// Returns the read only operands that are passed to the region with the given 33104253320SMarcel Koester /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more 33204253320SMarcel Koester /// information. 33304253320SMarcel Koester Optional<OperandRange> 33404253320SMarcel Koester mlir::getRegionBranchSuccessorOperands(Operation *operation, 33504253320SMarcel Koester Optional<unsigned> regionIndex) { 33604253320SMarcel Koester auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex); 33704253320SMarcel Koester return range ? Optional<OperandRange>(*range) : llvm::None; 33804253320SMarcel Koester } 339