1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ControlFlowInterfaces.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "llvm/ADT/SmallPtrSet.h" 12 13 using namespace mlir; 14 15 //===----------------------------------------------------------------------===// 16 // ControlFlowInterfaces 17 //===----------------------------------------------------------------------===// 18 19 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" 20 21 //===----------------------------------------------------------------------===// 22 // BranchOpInterface 23 //===----------------------------------------------------------------------===// 24 25 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some 26 /// successor if 'operandIndex' is within the range of 'operands', or None if 27 /// `operandIndex` isn't a successor operand index. 28 Optional<BlockArgument> 29 detail::getBranchSuccessorArgument(Optional<OperandRange> operands, 30 unsigned operandIndex, Block *successor) { 31 // Check that the operands are valid. 32 if (!operands || operands->empty()) 33 return llvm::None; 34 35 // Check to ensure that this operand is within the range. 36 unsigned operandsStart = operands->getBeginOperandIndex(); 37 if (operandIndex < operandsStart || 38 operandIndex >= (operandsStart + operands->size())) 39 return llvm::None; 40 41 // Index the successor. 42 unsigned argIndex = operandIndex - operandsStart; 43 return successor->getArgument(argIndex); 44 } 45 46 /// Verify that the given operands match those of the given successor block. 47 LogicalResult 48 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, 49 Optional<OperandRange> operands) { 50 if (!operands) 51 return success(); 52 53 // Check the count. 54 unsigned operandCount = operands->size(); 55 Block *destBB = op->getSuccessor(succNo); 56 if (operandCount != destBB->getNumArguments()) 57 return op->emitError() << "branch has " << operandCount 58 << " operands for successor #" << succNo 59 << ", but target block has " 60 << destBB->getNumArguments(); 61 62 // Check the types. 63 auto operandIt = operands->begin(); 64 for (unsigned i = 0; i != operandCount; ++i, ++operandIt) { 65 if ((*operandIt).getType() != destBB->getArgument(i).getType()) 66 return op->emitError() << "type mismatch for bb argument #" << i 67 << " of successor #" << succNo; 68 } 69 return success(); 70 } 71 72 //===----------------------------------------------------------------------===// 73 // RegionBranchOpInterface 74 //===----------------------------------------------------------------------===// 75 76 /// Verify that types match along all region control flow edges originating from 77 /// `sourceNo` (region # if source is a region, llvm::None if source is parent 78 /// op). `getInputsTypesForRegion` is a function that returns the types of the 79 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if 80 /// the exact type match verification is not necessary (e.g., if the Op verifies 81 /// the match itself). 82 static LogicalResult 83 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo, 84 function_ref<Optional<TypeRange>(Optional<unsigned>)> 85 getInputsTypesForRegion) { 86 auto regionInterface = cast<RegionBranchOpInterface>(op); 87 88 SmallVector<RegionSuccessor, 2> successors; 89 unsigned numInputs; 90 if (sourceNo) { 91 Region &srcRegion = op->getRegion(sourceNo.getValue()); 92 numInputs = srcRegion.getNumArguments(); 93 } else { 94 numInputs = op->getNumOperands(); 95 } 96 SmallVector<Attribute, 2> operands(numInputs, nullptr); 97 regionInterface.getSuccessorRegions(sourceNo, operands, successors); 98 99 for (RegionSuccessor &succ : successors) { 100 Optional<unsigned> succRegionNo; 101 if (!succ.isParent()) 102 succRegionNo = succ.getSuccessor()->getRegionNumber(); 103 104 auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { 105 diag << "from "; 106 if (sourceNo) 107 diag << "Region #" << sourceNo.getValue(); 108 else 109 diag << "parent operands"; 110 111 diag << " to "; 112 if (succRegionNo) 113 diag << "Region #" << succRegionNo.getValue(); 114 else 115 diag << "parent results"; 116 return diag; 117 }; 118 119 Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo); 120 if (!sourceTypes.hasValue()) 121 continue; 122 123 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); 124 if (sourceTypes->size() != succInputsTypes.size()) { 125 InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); 126 return printEdgeName(diag) << ": source has " << sourceTypes->size() 127 << " operands, but target successor needs " 128 << succInputsTypes.size(); 129 } 130 131 for (const auto &typesIdx : 132 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { 133 Type sourceType = std::get<0>(typesIdx.value()); 134 Type inputType = std::get<1>(typesIdx.value()); 135 if (sourceType != inputType) { 136 InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); 137 return printEdgeName(diag) 138 << ": source type #" << typesIdx.index() << " " << sourceType 139 << " should match input type #" << typesIdx.index() << " " 140 << inputType; 141 } 142 } 143 } 144 return success(); 145 } 146 147 /// Verify that types match along control flow edges described the given op. 148 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { 149 auto regionInterface = cast<RegionBranchOpInterface>(op); 150 151 auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange { 152 if (regionNo.hasValue()) { 153 return regionInterface.getSuccessorEntryOperands(regionNo.getValue()) 154 .getTypes(); 155 } 156 157 // If the successor of a parent op is the parent itself 158 // RegionBranchOpInterface does not have an API to query what the entry 159 // operands will be in that case. Vend out the result types of the op in 160 // that case so that type checking succeeds for this case. 161 return op->getResultTypes(); 162 }; 163 164 // Verify types along control flow edges originating from the parent. 165 if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent))) 166 return failure(); 167 168 // RegionBranchOpInterface should not be implemented by Ops that do not have 169 // attached regions. 170 assert(op->getNumRegions() != 0); 171 172 // Verify types along control flow edges originating from each region. 173 for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { 174 Region ®ion = op->getRegion(regionNo); 175 176 // Since there can be multiple `ReturnLike` terminators or others 177 // implementing the `RegionBranchTerminatorOpInterface`, all should have the 178 // same operand types when passing them to the same region. 179 180 Optional<OperandRange> regionReturnOperands; 181 for (Block &block : region) { 182 Operation *terminator = block.getTerminator(); 183 auto terminatorOperands = 184 getRegionBranchSuccessorOperands(terminator, regionNo); 185 if (!terminatorOperands) 186 continue; 187 188 if (!regionReturnOperands) { 189 regionReturnOperands = terminatorOperands; 190 continue; 191 } 192 193 // Found more than one ReturnLike terminator. Make sure the operand types 194 // match with the first one. 195 if (regionReturnOperands->getTypes() != terminatorOperands->getTypes()) 196 return op->emitOpError("Region #") 197 << regionNo 198 << " operands mismatch between return-like terminators"; 199 } 200 201 auto inputTypesFromRegion = 202 [&](Optional<unsigned> regionNo) -> Optional<TypeRange> { 203 // If there is no return-like terminator, the op itself should verify 204 // type consistency. 205 if (!regionReturnOperands) 206 return llvm::None; 207 208 // All successors get the same set of operand types. 209 return TypeRange(regionReturnOperands->getTypes()); 210 }; 211 212 if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) 213 return failure(); 214 } 215 216 return success(); 217 } 218 219 /// Return `true` if `a` and `b` are in mutually exclusive regions. 220 /// 221 /// 1. Find the first common of `a` and `b` (ancestor) that implements 222 /// RegionBranchOpInterface. 223 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are 224 /// contained. 225 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are 226 /// mutually exclusive if they are not reachable from each other as per 227 /// RegionBranchOpInterface::getSuccessorRegions. 228 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { 229 assert(a && "expected non-empty operation"); 230 assert(b && "expected non-empty operation"); 231 232 auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); 233 while (branchOp) { 234 // Check if b is inside branchOp. (We already know that a is.) 235 if (!branchOp->isProperAncestor(b)) { 236 // Check next enclosing RegionBranchOpInterface. 237 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>(); 238 continue; 239 } 240 241 // b is contained in branchOp. Retrieve the regions in which `a` and `b` 242 // are contained. 243 Region *regionA = nullptr, *regionB = nullptr; 244 for (Region &r : branchOp->getRegions()) { 245 if (r.findAncestorOpInRegion(*a)) { 246 assert(!regionA && "already found a region for a"); 247 regionA = &r; 248 } 249 if (r.findAncestorOpInRegion(*b)) { 250 assert(!regionB && "already found a region for b"); 251 regionB = &r; 252 } 253 } 254 assert(regionA && regionB && "could not find region of op"); 255 256 // Helper function that checks if region `r` is reachable from region 257 // `begin`. 258 std::function<bool(Region *, Region *)> isRegionReachable = 259 [&](Region *begin, Region *r) { 260 if (begin == r) 261 return true; 262 if (begin == nullptr) 263 return false; 264 // Compute index of region. 265 int64_t beginIndex = -1; 266 for (const auto &it : llvm::enumerate(branchOp->getRegions())) 267 if (&it.value() == begin) 268 beginIndex = it.index(); 269 assert(beginIndex != -1 && "could not find region in op"); 270 // Retrieve all successors of the region. 271 SmallVector<RegionSuccessor> successors; 272 branchOp.getSuccessorRegions(beginIndex, successors); 273 // Call function recursively on all successors. 274 for (RegionSuccessor successor : successors) 275 if (isRegionReachable(successor.getSuccessor(), r)) 276 return true; 277 return false; 278 }; 279 280 // `a` and `b` are in mutually exclusive regions if neither region is 281 // reachable from the other region. 282 return !isRegionReachable(regionA, regionB) && 283 !isRegionReachable(regionB, regionA); 284 } 285 286 // Could not find a common RegionBranchOpInterface among a's and b's 287 // ancestors. 288 return false; 289 } 290 291 //===----------------------------------------------------------------------===// 292 // RegionBranchTerminatorOpInterface 293 //===----------------------------------------------------------------------===// 294 295 /// Returns true if the given operation is either annotated with the 296 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`. 297 bool mlir::isRegionReturnLike(Operation *operation) { 298 return dyn_cast<RegionBranchTerminatorOpInterface>(operation) || 299 operation->hasTrait<OpTrait::ReturnLike>(); 300 } 301 302 /// Returns the mutable operands that are passed to the region with the given 303 /// `regionIndex`. If the operation does not implement the 304 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the 305 /// result will be `llvm::None`. In all other cases, the resulting 306 /// `OperandRange` represents all operands that are passed to the specified 307 /// successor region. If `regionIndex` is `llvm::None`, all operands that are 308 /// passed to the parent operation will be returned. 309 Optional<MutableOperandRange> 310 mlir::getMutableRegionBranchSuccessorOperands(Operation *operation, 311 Optional<unsigned> regionIndex) { 312 // Try to query a RegionBranchTerminatorOpInterface to determine 313 // all successor operands that will be passed to the successor 314 // input arguments. 315 if (auto regionTerminatorInterface = 316 dyn_cast<RegionBranchTerminatorOpInterface>(operation)) 317 return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex); 318 319 // TODO: The ReturnLike trait should imply a default implementation of the 320 // RegionBranchTerminatorOpInterface. This would make this code significantly 321 // easier. Furthermore, this may even make this function obsolete. 322 if (operation->hasTrait<OpTrait::ReturnLike>()) 323 return MutableOperandRange(operation); 324 return llvm::None; 325 } 326 327 /// Returns the read only operands that are passed to the region with the given 328 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more 329 /// information. 330 Optional<OperandRange> 331 mlir::getRegionBranchSuccessorOperands(Operation *operation, 332 Optional<unsigned> regionIndex) { 333 auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex); 334 return range ? Optional<OperandRange>(*range) : llvm::None; 335 } 336