1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ControlFlowInterfaces.h" 10 #include "mlir/IR/StandardTypes.h" 11 #include "llvm/ADT/SmallPtrSet.h" 12 13 using namespace mlir; 14 15 //===----------------------------------------------------------------------===// 16 // ControlFlowInterfaces 17 //===----------------------------------------------------------------------===// 18 19 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" 20 21 //===----------------------------------------------------------------------===// 22 // BranchOpInterface 23 //===----------------------------------------------------------------------===// 24 25 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some 26 /// successor if 'operandIndex' is within the range of 'operands', or None if 27 /// `operandIndex` isn't a successor operand index. 28 Optional<BlockArgument> 29 detail::getBranchSuccessorArgument(Optional<OperandRange> operands, 30 unsigned operandIndex, Block *successor) { 31 // Check that the operands are valid. 32 if (!operands || operands->empty()) 33 return llvm::None; 34 35 // Check to ensure that this operand is within the range. 36 unsigned operandsStart = operands->getBeginOperandIndex(); 37 if (operandIndex < operandsStart || 38 operandIndex >= (operandsStart + operands->size())) 39 return llvm::None; 40 41 // Index the successor. 42 unsigned argIndex = operandIndex - operandsStart; 43 return successor->getArgument(argIndex); 44 } 45 46 /// Verify that the given operands match those of the given successor block. 47 LogicalResult 48 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, 49 Optional<OperandRange> operands) { 50 if (!operands) 51 return success(); 52 53 // Check the count. 54 unsigned operandCount = operands->size(); 55 Block *destBB = op->getSuccessor(succNo); 56 if (operandCount != destBB->getNumArguments()) 57 return op->emitError() << "branch has " << operandCount 58 << " operands for successor #" << succNo 59 << ", but target block has " 60 << destBB->getNumArguments(); 61 62 // Check the types. 63 auto operandIt = operands->begin(); 64 for (unsigned i = 0; i != operandCount; ++i, ++operandIt) { 65 if ((*operandIt).getType() != destBB->getArgument(i).getType()) 66 return op->emitError() << "type mismatch for bb argument #" << i 67 << " of successor #" << succNo; 68 } 69 return success(); 70 } 71 72 //===----------------------------------------------------------------------===// 73 // RegionBranchOpInterface 74 //===----------------------------------------------------------------------===// 75 76 /// Verify that types match along all region control flow edges originating from 77 /// `sourceNo` (region # if source is a region, llvm::None if source is parent 78 /// op). `getInputsTypesForRegion` is a function that returns the types of the 79 /// inputs that flow from `sourceIndex' to the given region. 80 static LogicalResult verifyTypesAlongAllEdges( 81 Operation *op, Optional<unsigned> sourceNo, 82 function_ref<TypeRange(Optional<unsigned>)> getInputsTypesForRegion) { 83 auto regionInterface = cast<RegionBranchOpInterface>(op); 84 85 SmallVector<RegionSuccessor, 2> successors; 86 unsigned numInputs; 87 if (sourceNo) { 88 Region &srcRegion = op->getRegion(sourceNo.getValue()); 89 numInputs = srcRegion.getNumArguments(); 90 } else { 91 numInputs = op->getNumOperands(); 92 } 93 SmallVector<Attribute, 2> operands(numInputs, nullptr); 94 regionInterface.getSuccessorRegions(sourceNo, operands, successors); 95 96 for (RegionSuccessor &succ : successors) { 97 Optional<unsigned> succRegionNo; 98 if (!succ.isParent()) 99 succRegionNo = succ.getSuccessor()->getRegionNumber(); 100 101 auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { 102 diag << "from "; 103 if (sourceNo) 104 diag << "Region #" << sourceNo.getValue(); 105 else 106 diag << "parent operands"; 107 108 diag << " to "; 109 if (succRegionNo) 110 diag << "Region #" << succRegionNo.getValue(); 111 else 112 diag << "parent results"; 113 return diag; 114 }; 115 116 TypeRange sourceTypes = getInputsTypesForRegion(succRegionNo); 117 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); 118 if (sourceTypes.size() != succInputsTypes.size()) { 119 InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); 120 return printEdgeName(diag) << ": source has " << sourceTypes.size() 121 << " operands, but target successor needs " 122 << succInputsTypes.size(); 123 } 124 125 for (auto typesIdx : 126 llvm::enumerate(llvm::zip(sourceTypes, succInputsTypes))) { 127 Type sourceType = std::get<0>(typesIdx.value()); 128 Type inputType = std::get<1>(typesIdx.value()); 129 if (sourceType != inputType) { 130 InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); 131 return printEdgeName(diag) 132 << ": source type #" << typesIdx.index() << " " << sourceType 133 << " should match input type #" << typesIdx.index() << " " 134 << inputType; 135 } 136 } 137 } 138 return success(); 139 } 140 141 /// Verify that types match along control flow edges described the given op. 142 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { 143 auto regionInterface = cast<RegionBranchOpInterface>(op); 144 145 auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange { 146 if (regionNo.hasValue()) { 147 return regionInterface.getSuccessorEntryOperands(regionNo.getValue()) 148 .getTypes(); 149 } 150 151 // If the successor of a parent op is the parent itself 152 // RegionBranchOpInterface does not have an API to query what the entry 153 // operands will be in that case. Vend out the result types of the op in 154 // that case so that type checking succeeds for this case. 155 return op->getResultTypes(); 156 }; 157 158 // Verify types along control flow edges originating from the parent. 159 if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent))) 160 return failure(); 161 162 // RegionBranchOpInterface should not be implemented by Ops that do not have 163 // attached regions. 164 assert(op->getNumRegions() != 0); 165 166 // Verify types along control flow edges originating from each region. 167 for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { 168 Region ®ion = op->getRegion(regionNo); 169 170 // Since the interface cannnot distinguish between different ReturnLike 171 // ops within the region branching to different successors, all ReturnLike 172 // ops in this region should have the same operand types. We will then use 173 // one of them as the representative for type matching. 174 175 Operation *regionReturn = nullptr; 176 for (Block &block : region) { 177 Operation *terminator = block.getTerminator(); 178 if (!terminator->hasTrait<OpTrait::ReturnLike>()) 179 continue; 180 181 if (!regionReturn) { 182 regionReturn = terminator; 183 continue; 184 } 185 186 // Found more than one ReturnLike terminator. Make sure the operand types 187 // match with the first one. 188 if (regionReturn->getOperandTypes() != terminator->getOperandTypes()) 189 return op->emitOpError("Region #") 190 << regionNo 191 << " operands mismatch between return-like terminators"; 192 } 193 194 auto inputTypesFromRegion = [&](Optional<unsigned> regionNo) -> TypeRange { 195 // All successors get the same set of operands. 196 return regionReturn ? TypeRange(regionReturn->getOperands().getTypes()) 197 : TypeRange(); 198 }; 199 200 if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) 201 return failure(); 202 } 203 204 return success(); 205 } 206