17ce1e7abSRiver Riddle //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// 27ce1e7abSRiver Riddle // 37ce1e7abSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 47ce1e7abSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 57ce1e7abSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 67ce1e7abSRiver Riddle // 77ce1e7abSRiver Riddle //===----------------------------------------------------------------------===// 87ce1e7abSRiver Riddle 97ce1e7abSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.h" 107ce1e7abSRiver Riddle #include "mlir/IR/StandardTypes.h" 11a3ad8f92SRahul Joshi #include "llvm/ADT/SmallPtrSet.h" 127ce1e7abSRiver Riddle 137ce1e7abSRiver Riddle using namespace mlir; 147ce1e7abSRiver Riddle 157ce1e7abSRiver Riddle //===----------------------------------------------------------------------===// 167ce1e7abSRiver Riddle // ControlFlowInterfaces 177ce1e7abSRiver Riddle //===----------------------------------------------------------------------===// 187ce1e7abSRiver Riddle 197ce1e7abSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" 207ce1e7abSRiver Riddle 217ce1e7abSRiver Riddle //===----------------------------------------------------------------------===// 227ce1e7abSRiver Riddle // BranchOpInterface 237ce1e7abSRiver Riddle //===----------------------------------------------------------------------===// 247ce1e7abSRiver Riddle 257ce1e7abSRiver Riddle /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some 267ce1e7abSRiver Riddle /// successor if 'operandIndex' is within the range of 'operands', or None if 277ce1e7abSRiver Riddle /// `operandIndex` isn't a successor operand index. 28a3ad8f92SRahul Joshi Optional<BlockArgument> 29a3ad8f92SRahul Joshi detail::getBranchSuccessorArgument(Optional<OperandRange> operands, 30a3ad8f92SRahul Joshi unsigned operandIndex, Block *successor) { 317ce1e7abSRiver Riddle // Check that the operands are valid. 327ce1e7abSRiver Riddle if (!operands || operands->empty()) 337ce1e7abSRiver Riddle return llvm::None; 347ce1e7abSRiver Riddle 357ce1e7abSRiver Riddle // Check to ensure that this operand is within the range. 367ce1e7abSRiver Riddle unsigned operandsStart = operands->getBeginOperandIndex(); 377ce1e7abSRiver Riddle if (operandIndex < operandsStart || 387ce1e7abSRiver Riddle operandIndex >= (operandsStart + operands->size())) 397ce1e7abSRiver Riddle return llvm::None; 407ce1e7abSRiver Riddle 417ce1e7abSRiver Riddle // Index the successor. 427ce1e7abSRiver Riddle unsigned argIndex = operandIndex - operandsStart; 437ce1e7abSRiver Riddle return successor->getArgument(argIndex); 447ce1e7abSRiver Riddle } 457ce1e7abSRiver Riddle 467ce1e7abSRiver Riddle /// Verify that the given operands match those of the given successor block. 477ce1e7abSRiver Riddle LogicalResult 48a3ad8f92SRahul Joshi detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, 497ce1e7abSRiver Riddle Optional<OperandRange> operands) { 507ce1e7abSRiver Riddle if (!operands) 517ce1e7abSRiver Riddle return success(); 527ce1e7abSRiver Riddle 537ce1e7abSRiver Riddle // Check the count. 547ce1e7abSRiver Riddle unsigned operandCount = operands->size(); 557ce1e7abSRiver Riddle Block *destBB = op->getSuccessor(succNo); 567ce1e7abSRiver Riddle if (operandCount != destBB->getNumArguments()) 577ce1e7abSRiver Riddle return op->emitError() << "branch has " << operandCount 587ce1e7abSRiver Riddle << " operands for successor #" << succNo 597ce1e7abSRiver Riddle << ", but target block has " 607ce1e7abSRiver Riddle << destBB->getNumArguments(); 617ce1e7abSRiver Riddle 627ce1e7abSRiver Riddle // Check the types. 637ce1e7abSRiver Riddle auto operandIt = operands->begin(); 647ce1e7abSRiver Riddle for (unsigned i = 0; i != operandCount; ++i, ++operandIt) { 657ce1e7abSRiver Riddle if ((*operandIt).getType() != destBB->getArgument(i).getType()) 667ce1e7abSRiver Riddle return op->emitError() << "type mismatch for bb argument #" << i 677ce1e7abSRiver Riddle << " of successor #" << succNo; 687ce1e7abSRiver Riddle } 697ce1e7abSRiver Riddle return success(); 707ce1e7abSRiver Riddle } 71a3ad8f92SRahul Joshi 72a3ad8f92SRahul Joshi //===----------------------------------------------------------------------===// 73a3ad8f92SRahul Joshi // RegionBranchOpInterface 74a3ad8f92SRahul Joshi //===----------------------------------------------------------------------===// 75a3ad8f92SRahul Joshi 76a3ad8f92SRahul Joshi /// Verify that types match along all region control flow edges originating from 77a3ad8f92SRahul Joshi /// `sourceNo` (region # if source is a region, llvm::None if source is parent 78a3ad8f92SRahul Joshi /// op). `getInputsTypesForRegion` is a function that returns the types of the 79a3ad8f92SRahul Joshi /// inputs that flow from `sourceIndex' to the given region. 80a3ad8f92SRahul Joshi static LogicalResult verifyTypesAlongAllEdges( 81a3ad8f92SRahul Joshi Operation *op, Optional<unsigned> sourceNo, 82a3ad8f92SRahul Joshi function_ref<TypeRange(Optional<unsigned>)> getInputsTypesForRegion) { 83a3ad8f92SRahul Joshi auto regionInterface = cast<RegionBranchOpInterface>(op); 84a3ad8f92SRahul Joshi 85a3ad8f92SRahul Joshi SmallVector<RegionSuccessor, 2> successors; 86a3ad8f92SRahul Joshi unsigned numInputs; 87a3ad8f92SRahul Joshi if (sourceNo) { 88a3ad8f92SRahul Joshi Region &srcRegion = op->getRegion(sourceNo.getValue()); 89a3ad8f92SRahul Joshi numInputs = srcRegion.getNumArguments(); 90a3ad8f92SRahul Joshi } else { 91a3ad8f92SRahul Joshi numInputs = op->getNumOperands(); 92a3ad8f92SRahul Joshi } 93a3ad8f92SRahul Joshi SmallVector<Attribute, 2> operands(numInputs, nullptr); 94a3ad8f92SRahul Joshi regionInterface.getSuccessorRegions(sourceNo, operands, successors); 95a3ad8f92SRahul Joshi 96a3ad8f92SRahul Joshi for (RegionSuccessor &succ : successors) { 97a3ad8f92SRahul Joshi Optional<unsigned> succRegionNo; 98a3ad8f92SRahul Joshi if (!succ.isParent()) 99a3ad8f92SRahul Joshi succRegionNo = succ.getSuccessor()->getRegionNumber(); 100a3ad8f92SRahul Joshi 101a3ad8f92SRahul Joshi auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { 102a3ad8f92SRahul Joshi diag << "from "; 103a3ad8f92SRahul Joshi if (sourceNo) 104a3ad8f92SRahul Joshi diag << "Region #" << sourceNo.getValue(); 105a3ad8f92SRahul Joshi else 106be35264aSSean Silva diag << "parent operands"; 107a3ad8f92SRahul Joshi 108a3ad8f92SRahul Joshi diag << " to "; 109a3ad8f92SRahul Joshi if (succRegionNo) 110a3ad8f92SRahul Joshi diag << "Region #" << succRegionNo.getValue(); 111a3ad8f92SRahul Joshi else 112be35264aSSean Silva diag << "parent results"; 113a3ad8f92SRahul Joshi return diag; 114a3ad8f92SRahul Joshi }; 115a3ad8f92SRahul Joshi 116a3ad8f92SRahul Joshi TypeRange sourceTypes = getInputsTypesForRegion(succRegionNo); 117a3ad8f92SRahul Joshi TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); 118a3ad8f92SRahul Joshi if (sourceTypes.size() != succInputsTypes.size()) { 119a3ad8f92SRahul Joshi InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); 120be35264aSSean Silva return printEdgeName(diag) << ": source has " << sourceTypes.size() 121be35264aSSean Silva << " operands, but target successor needs " 122a3ad8f92SRahul Joshi << succInputsTypes.size(); 123a3ad8f92SRahul Joshi } 124a3ad8f92SRahul Joshi 125a3ad8f92SRahul Joshi for (auto typesIdx : 126a3ad8f92SRahul Joshi llvm::enumerate(llvm::zip(sourceTypes, succInputsTypes))) { 127a3ad8f92SRahul Joshi Type sourceType = std::get<0>(typesIdx.value()); 128a3ad8f92SRahul Joshi Type inputType = std::get<1>(typesIdx.value()); 129a3ad8f92SRahul Joshi if (sourceType != inputType) { 130a3ad8f92SRahul Joshi InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); 131a3ad8f92SRahul Joshi return printEdgeName(diag) 132be35264aSSean Silva << ": source type #" << typesIdx.index() << " " << sourceType 133be35264aSSean Silva << " should match input type #" << typesIdx.index() << " " 134a3ad8f92SRahul Joshi << inputType; 135a3ad8f92SRahul Joshi } 136a3ad8f92SRahul Joshi } 137a3ad8f92SRahul Joshi } 138a3ad8f92SRahul Joshi return success(); 139a3ad8f92SRahul Joshi } 140a3ad8f92SRahul Joshi 141a3ad8f92SRahul Joshi /// Verify that types match along control flow edges described the given op. 142a3ad8f92SRahul Joshi LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { 143a3ad8f92SRahul Joshi auto regionInterface = cast<RegionBranchOpInterface>(op); 144a3ad8f92SRahul Joshi 145a3ad8f92SRahul Joshi auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange { 146a3ad8f92SRahul Joshi if (regionNo.hasValue()) { 147a3ad8f92SRahul Joshi return regionInterface.getSuccessorEntryOperands(regionNo.getValue()) 148a3ad8f92SRahul Joshi .getTypes(); 149a3ad8f92SRahul Joshi } 150a3ad8f92SRahul Joshi 151a3ad8f92SRahul Joshi // If the successor of a parent op is the parent itself 152a3ad8f92SRahul Joshi // RegionBranchOpInterface does not have an API to query what the entry 153a3ad8f92SRahul Joshi // operands will be in that case. Vend out the result types of the op in 154a3ad8f92SRahul Joshi // that case so that type checking succeeds for this case. 155a3ad8f92SRahul Joshi return op->getResultTypes(); 156a3ad8f92SRahul Joshi }; 157a3ad8f92SRahul Joshi 158a3ad8f92SRahul Joshi // Verify types along control flow edges originating from the parent. 159a3ad8f92SRahul Joshi if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent))) 160a3ad8f92SRahul Joshi return failure(); 161a3ad8f92SRahul Joshi 162a3ad8f92SRahul Joshi // RegionBranchOpInterface should not be implemented by Ops that do not have 163a3ad8f92SRahul Joshi // attached regions. 164a3ad8f92SRahul Joshi assert(op->getNumRegions() != 0); 165a3ad8f92SRahul Joshi 166a3ad8f92SRahul Joshi // Verify types along control flow edges originating from each region. 167a3ad8f92SRahul Joshi for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { 168a3ad8f92SRahul Joshi Region ®ion = op->getRegion(regionNo); 169a3ad8f92SRahul Joshi 170*41b09f4eSKazuaki Ishizaki // Since the interface cannot distinguish between different ReturnLike 171a3ad8f92SRahul Joshi // ops within the region branching to different successors, all ReturnLike 172a3ad8f92SRahul Joshi // ops in this region should have the same operand types. We will then use 173a3ad8f92SRahul Joshi // one of them as the representative for type matching. 174a3ad8f92SRahul Joshi 175a3ad8f92SRahul Joshi Operation *regionReturn = nullptr; 176a3ad8f92SRahul Joshi for (Block &block : region) { 177a3ad8f92SRahul Joshi Operation *terminator = block.getTerminator(); 178a3ad8f92SRahul Joshi if (!terminator->hasTrait<OpTrait::ReturnLike>()) 179a3ad8f92SRahul Joshi continue; 180a3ad8f92SRahul Joshi 181a3ad8f92SRahul Joshi if (!regionReturn) { 182a3ad8f92SRahul Joshi regionReturn = terminator; 183a3ad8f92SRahul Joshi continue; 184a3ad8f92SRahul Joshi } 185a3ad8f92SRahul Joshi 186a3ad8f92SRahul Joshi // Found more than one ReturnLike terminator. Make sure the operand types 187a3ad8f92SRahul Joshi // match with the first one. 188a3ad8f92SRahul Joshi if (regionReturn->getOperandTypes() != terminator->getOperandTypes()) 189a3ad8f92SRahul Joshi return op->emitOpError("Region #") 190a3ad8f92SRahul Joshi << regionNo 191a3ad8f92SRahul Joshi << " operands mismatch between return-like terminators"; 192a3ad8f92SRahul Joshi } 193a3ad8f92SRahul Joshi 194a3ad8f92SRahul Joshi auto inputTypesFromRegion = [&](Optional<unsigned> regionNo) -> TypeRange { 195a3ad8f92SRahul Joshi // All successors get the same set of operands. 196a3ad8f92SRahul Joshi return regionReturn ? TypeRange(regionReturn->getOperands().getTypes()) 197a3ad8f92SRahul Joshi : TypeRange(); 198a3ad8f92SRahul Joshi }; 199a3ad8f92SRahul Joshi 200a3ad8f92SRahul Joshi if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) 201a3ad8f92SRahul Joshi return failure(); 202a3ad8f92SRahul Joshi } 203a3ad8f92SRahul Joshi 204a3ad8f92SRahul Joshi return success(); 205a3ad8f92SRahul Joshi } 206