1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ControlFlowInterfaces.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "llvm/ADT/SmallPtrSet.h" 12 13 using namespace mlir; 14 15 //===----------------------------------------------------------------------===// 16 // ControlFlowInterfaces 17 //===----------------------------------------------------------------------===// 18 19 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" 20 21 //===----------------------------------------------------------------------===// 22 // BranchOpInterface 23 //===----------------------------------------------------------------------===// 24 25 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some 26 /// successor if 'operandIndex' is within the range of 'operands', or None if 27 /// `operandIndex` isn't a successor operand index. 28 Optional<BlockArgument> 29 detail::getBranchSuccessorArgument(Optional<OperandRange> operands, 30 unsigned operandIndex, Block *successor) { 31 // Check that the operands are valid. 32 if (!operands || operands->empty()) 33 return llvm::None; 34 35 // Check to ensure that this operand is within the range. 36 unsigned operandsStart = operands->getBeginOperandIndex(); 37 if (operandIndex < operandsStart || 38 operandIndex >= (operandsStart + operands->size())) 39 return llvm::None; 40 41 // Index the successor. 42 unsigned argIndex = operandIndex - operandsStart; 43 return successor->getArgument(argIndex); 44 } 45 46 /// Verify that the given operands match those of the given successor block. 47 LogicalResult 48 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, 49 Optional<OperandRange> operands) { 50 if (!operands) 51 return success(); 52 53 // Check the count. 54 unsigned operandCount = operands->size(); 55 Block *destBB = op->getSuccessor(succNo); 56 if (operandCount != destBB->getNumArguments()) 57 return op->emitError() << "branch has " << operandCount 58 << " operands for successor #" << succNo 59 << ", but target block has " 60 << destBB->getNumArguments(); 61 62 // Check the types. 63 auto operandIt = operands->begin(); 64 for (unsigned i = 0; i != operandCount; ++i, ++operandIt) { 65 if ((*operandIt).getType() != destBB->getArgument(i).getType()) 66 return op->emitError() << "type mismatch for bb argument #" << i 67 << " of successor #" << succNo; 68 } 69 return success(); 70 } 71 72 //===----------------------------------------------------------------------===// 73 // RegionBranchOpInterface 74 //===----------------------------------------------------------------------===// 75 76 // A constant value to represent unknown number of region invocations. 77 const int64_t mlir::kUnknownNumRegionInvocations = -1; 78 79 /// Verify that types match along all region control flow edges originating from 80 /// `sourceNo` (region # if source is a region, llvm::None if source is parent 81 /// op). `getInputsTypesForRegion` is a function that returns the types of the 82 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if 83 /// the exact type match verification is not necessary (e.g., if the Op verifies 84 /// the match itself). 85 static LogicalResult 86 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo, 87 function_ref<Optional<TypeRange>(Optional<unsigned>)> 88 getInputsTypesForRegion) { 89 auto regionInterface = cast<RegionBranchOpInterface>(op); 90 91 SmallVector<RegionSuccessor, 2> successors; 92 unsigned numInputs; 93 if (sourceNo) { 94 Region &srcRegion = op->getRegion(sourceNo.getValue()); 95 numInputs = srcRegion.getNumArguments(); 96 } else { 97 numInputs = op->getNumOperands(); 98 } 99 SmallVector<Attribute, 2> operands(numInputs, nullptr); 100 regionInterface.getSuccessorRegions(sourceNo, operands, successors); 101 102 for (RegionSuccessor &succ : successors) { 103 Optional<unsigned> succRegionNo; 104 if (!succ.isParent()) 105 succRegionNo = succ.getSuccessor()->getRegionNumber(); 106 107 auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { 108 diag << "from "; 109 if (sourceNo) 110 diag << "Region #" << sourceNo.getValue(); 111 else 112 diag << "parent operands"; 113 114 diag << " to "; 115 if (succRegionNo) 116 diag << "Region #" << succRegionNo.getValue(); 117 else 118 diag << "parent results"; 119 return diag; 120 }; 121 122 Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo); 123 if (!sourceTypes.hasValue()) 124 continue; 125 126 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); 127 if (sourceTypes->size() != succInputsTypes.size()) { 128 InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); 129 return printEdgeName(diag) << ": source has " << sourceTypes->size() 130 << " operands, but target successor needs " 131 << succInputsTypes.size(); 132 } 133 134 for (auto typesIdx : 135 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { 136 Type sourceType = std::get<0>(typesIdx.value()); 137 Type inputType = std::get<1>(typesIdx.value()); 138 if (sourceType != inputType) { 139 InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); 140 return printEdgeName(diag) 141 << ": source type #" << typesIdx.index() << " " << sourceType 142 << " should match input type #" << typesIdx.index() << " " 143 << inputType; 144 } 145 } 146 } 147 return success(); 148 } 149 150 /// Verify that types match along control flow edges described the given op. 151 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { 152 auto regionInterface = cast<RegionBranchOpInterface>(op); 153 154 auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange { 155 if (regionNo.hasValue()) { 156 return regionInterface.getSuccessorEntryOperands(regionNo.getValue()) 157 .getTypes(); 158 } 159 160 // If the successor of a parent op is the parent itself 161 // RegionBranchOpInterface does not have an API to query what the entry 162 // operands will be in that case. Vend out the result types of the op in 163 // that case so that type checking succeeds for this case. 164 return op->getResultTypes(); 165 }; 166 167 // Verify types along control flow edges originating from the parent. 168 if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent))) 169 return failure(); 170 171 // RegionBranchOpInterface should not be implemented by Ops that do not have 172 // attached regions. 173 assert(op->getNumRegions() != 0); 174 175 // Verify types along control flow edges originating from each region. 176 for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { 177 Region ®ion = op->getRegion(regionNo); 178 179 // Since there can be multiple `ReturnLike` terminators or others 180 // implementing the `RegionBranchTerminatorOpInterface`, all should have the 181 // same operand types when passing them to the same region. 182 183 Optional<OperandRange> regionReturnOperands; 184 for (Block &block : region) { 185 Operation *terminator = block.getTerminator(); 186 auto terminatorOperands = 187 getRegionBranchSuccessorOperands(terminator, regionNo); 188 if (!terminatorOperands) 189 continue; 190 191 if (!regionReturnOperands) { 192 regionReturnOperands = terminatorOperands; 193 continue; 194 } 195 196 // Found more than one ReturnLike terminator. Make sure the operand types 197 // match with the first one. 198 if (regionReturnOperands->getTypes() != terminatorOperands->getTypes()) 199 return op->emitOpError("Region #") 200 << regionNo 201 << " operands mismatch between return-like terminators"; 202 } 203 204 auto inputTypesFromRegion = 205 [&](Optional<unsigned> regionNo) -> Optional<TypeRange> { 206 // If there is no return-like terminator, the op itself should verify 207 // type consistency. 208 if (!regionReturnOperands) 209 return llvm::None; 210 211 // All successors get the same set of operand types. 212 return TypeRange(regionReturnOperands->getTypes()); 213 }; 214 215 if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) 216 return failure(); 217 } 218 219 return success(); 220 } 221 222 //===----------------------------------------------------------------------===// 223 // RegionBranchTerminatorOpInterface 224 //===----------------------------------------------------------------------===// 225 226 /// Returns true if the given operation is either annotated with the 227 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`. 228 bool mlir::isRegionReturnLike(Operation *operation) { 229 return dyn_cast<RegionBranchTerminatorOpInterface>(operation) || 230 operation->hasTrait<OpTrait::ReturnLike>(); 231 } 232 233 /// Returns the mutable operands that are passed to the region with the given 234 /// `regionIndex`. If the operation does not implement the 235 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the 236 /// result will be `llvm::None`. In all other cases, the resulting 237 /// `OperandRange` represents all operands that are passed to the specified 238 /// successor region. If `regionIndex` is `llvm::None`, all operands that are 239 /// passed to the parent operation will be returned. 240 Optional<MutableOperandRange> 241 mlir::getMutableRegionBranchSuccessorOperands(Operation *operation, 242 Optional<unsigned> regionIndex) { 243 // Try to query a RegionBranchTerminatorOpInterface to determine 244 // all successor operands that will be passed to the successor 245 // input arguments. 246 if (auto regionTerminatorInterface = 247 dyn_cast<RegionBranchTerminatorOpInterface>(operation)) 248 return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex); 249 250 // TODO: The ReturnLike trait should imply a default implementation of the 251 // RegionBranchTerminatorOpInterface. This would make this code significantly 252 // easier. Furthermore, this may even make this function obsolete. 253 if (operation->hasTrait<OpTrait::ReturnLike>()) 254 return MutableOperandRange(operation); 255 return llvm::None; 256 } 257 258 /// Returns the read only operands that are passed to the region with the given 259 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more 260 /// information. 261 Optional<OperandRange> 262 mlir::getRegionBranchSuccessorOperands(Operation *operation, 263 Optional<unsigned> regionIndex) { 264 auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex); 265 return range ? Optional<OperandRange>(*range) : llvm::None; 266 } 267