1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ControlFlowInterfaces.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "llvm/ADT/SmallPtrSet.h" 12 13 using namespace mlir; 14 15 //===----------------------------------------------------------------------===// 16 // ControlFlowInterfaces 17 //===----------------------------------------------------------------------===// 18 19 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" 20 21 //===----------------------------------------------------------------------===// 22 // BranchOpInterface 23 //===----------------------------------------------------------------------===// 24 25 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some 26 /// successor if 'operandIndex' is within the range of 'operands', or None if 27 /// `operandIndex` isn't a successor operand index. 28 Optional<BlockArgument> 29 detail::getBranchSuccessorArgument(Optional<OperandRange> operands, 30 unsigned operandIndex, Block *successor) { 31 // Check that the operands are valid. 32 if (!operands || operands->empty()) 33 return llvm::None; 34 35 // Check to ensure that this operand is within the range. 36 unsigned operandsStart = operands->getBeginOperandIndex(); 37 if (operandIndex < operandsStart || 38 operandIndex >= (operandsStart + operands->size())) 39 return llvm::None; 40 41 // Index the successor. 42 unsigned argIndex = operandIndex - operandsStart; 43 return successor->getArgument(argIndex); 44 } 45 46 /// Verify that the given operands match those of the given successor block. 47 LogicalResult 48 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, 49 Optional<OperandRange> operands) { 50 if (!operands) 51 return success(); 52 53 // Check the count. 54 unsigned operandCount = operands->size(); 55 Block *destBB = op->getSuccessor(succNo); 56 if (operandCount != destBB->getNumArguments()) 57 return op->emitError() << "branch has " << operandCount 58 << " operands for successor #" << succNo 59 << ", but target block has " 60 << destBB->getNumArguments(); 61 62 // Check the types. 63 auto operandIt = operands->begin(); 64 for (unsigned i = 0; i != operandCount; ++i, ++operandIt) { 65 if (!cast<BranchOpInterface>(op).areTypesCompatible( 66 (*operandIt).getType(), destBB->getArgument(i).getType())) 67 return op->emitError() << "type mismatch for bb argument #" << i 68 << " of successor #" << succNo; 69 } 70 return success(); 71 } 72 73 //===----------------------------------------------------------------------===// 74 // RegionBranchOpInterface 75 //===----------------------------------------------------------------------===// 76 77 /// Verify that types match along all region control flow edges originating from 78 /// `sourceNo` (region # if source is a region, llvm::None if source is parent 79 /// op). `getInputsTypesForRegion` is a function that returns the types of the 80 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if 81 /// the exact type match verification is not necessary (e.g., if the Op verifies 82 /// the match itself). 83 static LogicalResult 84 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo, 85 function_ref<Optional<TypeRange>(Optional<unsigned>)> 86 getInputsTypesForRegion) { 87 auto regionInterface = cast<RegionBranchOpInterface>(op); 88 89 SmallVector<RegionSuccessor, 2> successors; 90 unsigned numInputs; 91 if (sourceNo) { 92 Region &srcRegion = op->getRegion(sourceNo.getValue()); 93 numInputs = srcRegion.getNumArguments(); 94 } else { 95 numInputs = op->getNumOperands(); 96 } 97 SmallVector<Attribute, 2> operands(numInputs, nullptr); 98 regionInterface.getSuccessorRegions(sourceNo, operands, successors); 99 100 for (RegionSuccessor &succ : successors) { 101 Optional<unsigned> succRegionNo; 102 if (!succ.isParent()) 103 succRegionNo = succ.getSuccessor()->getRegionNumber(); 104 105 auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { 106 diag << "from "; 107 if (sourceNo) 108 diag << "Region #" << sourceNo.getValue(); 109 else 110 diag << "parent operands"; 111 112 diag << " to "; 113 if (succRegionNo) 114 diag << "Region #" << succRegionNo.getValue(); 115 else 116 diag << "parent results"; 117 return diag; 118 }; 119 120 Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo); 121 if (!sourceTypes.hasValue()) 122 continue; 123 124 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); 125 if (sourceTypes->size() != succInputsTypes.size()) { 126 InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); 127 return printEdgeName(diag) << ": source has " << sourceTypes->size() 128 << " operands, but target successor needs " 129 << succInputsTypes.size(); 130 } 131 132 for (const auto &typesIdx : 133 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { 134 Type sourceType = std::get<0>(typesIdx.value()); 135 Type inputType = std::get<1>(typesIdx.value()); 136 if (!regionInterface.areTypesCompatible(sourceType, inputType)) { 137 InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); 138 return printEdgeName(diag) 139 << ": source type #" << typesIdx.index() << " " << sourceType 140 << " should match input type #" << typesIdx.index() << " " 141 << inputType; 142 } 143 } 144 } 145 return success(); 146 } 147 148 /// Verify that types match along control flow edges described the given op. 149 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { 150 auto regionInterface = cast<RegionBranchOpInterface>(op); 151 152 auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange { 153 if (regionNo.hasValue()) { 154 return regionInterface.getSuccessorEntryOperands(regionNo.getValue()) 155 .getTypes(); 156 } 157 158 // If the successor of a parent op is the parent itself 159 // RegionBranchOpInterface does not have an API to query what the entry 160 // operands will be in that case. Vend out the result types of the op in 161 // that case so that type checking succeeds for this case. 162 return op->getResultTypes(); 163 }; 164 165 // Verify types along control flow edges originating from the parent. 166 if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent))) 167 return failure(); 168 169 // RegionBranchOpInterface should not be implemented by Ops that do not have 170 // attached regions. 171 assert(op->getNumRegions() != 0); 172 173 auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { 174 if (lhs.size() != rhs.size()) 175 return false; 176 for (auto types : llvm::zip(lhs, rhs)) { 177 if (!regionInterface.areTypesCompatible(std::get<0>(types), 178 std::get<1>(types))) { 179 return false; 180 } 181 } 182 return true; 183 }; 184 185 // Verify types along control flow edges originating from each region. 186 for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { 187 Region ®ion = op->getRegion(regionNo); 188 189 // Since there can be multiple `ReturnLike` terminators or others 190 // implementing the `RegionBranchTerminatorOpInterface`, all should have the 191 // same operand types when passing them to the same region. 192 193 Optional<OperandRange> regionReturnOperands; 194 for (Block &block : region) { 195 Operation *terminator = block.getTerminator(); 196 auto terminatorOperands = 197 getRegionBranchSuccessorOperands(terminator, regionNo); 198 if (!terminatorOperands) 199 continue; 200 201 if (!regionReturnOperands) { 202 regionReturnOperands = terminatorOperands; 203 continue; 204 } 205 206 // Found more than one ReturnLike terminator. Make sure the operand types 207 // match with the first one. 208 if (!areTypesCompatible(regionReturnOperands->getTypes(), 209 terminatorOperands->getTypes())) 210 return op->emitOpError("Region #") 211 << regionNo 212 << " operands mismatch between return-like terminators"; 213 } 214 215 auto inputTypesFromRegion = 216 [&](Optional<unsigned> regionNo) -> Optional<TypeRange> { 217 // If there is no return-like terminator, the op itself should verify 218 // type consistency. 219 if (!regionReturnOperands) 220 return llvm::None; 221 222 // All successors get the same set of operand types. 223 return TypeRange(regionReturnOperands->getTypes()); 224 }; 225 226 if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) 227 return failure(); 228 } 229 230 return success(); 231 } 232 233 /// Return `true` if `a` and `b` are in mutually exclusive regions. 234 /// 235 /// 1. Find the first common of `a` and `b` (ancestor) that implements 236 /// RegionBranchOpInterface. 237 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are 238 /// contained. 239 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are 240 /// mutually exclusive if they are not reachable from each other as per 241 /// RegionBranchOpInterface::getSuccessorRegions. 242 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { 243 assert(a && "expected non-empty operation"); 244 assert(b && "expected non-empty operation"); 245 246 auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); 247 while (branchOp) { 248 // Check if b is inside branchOp. (We already know that a is.) 249 if (!branchOp->isProperAncestor(b)) { 250 // Check next enclosing RegionBranchOpInterface. 251 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>(); 252 continue; 253 } 254 255 // b is contained in branchOp. Retrieve the regions in which `a` and `b` 256 // are contained. 257 Region *regionA = nullptr, *regionB = nullptr; 258 for (Region &r : branchOp->getRegions()) { 259 if (r.findAncestorOpInRegion(*a)) { 260 assert(!regionA && "already found a region for a"); 261 regionA = &r; 262 } 263 if (r.findAncestorOpInRegion(*b)) { 264 assert(!regionB && "already found a region for b"); 265 regionB = &r; 266 } 267 } 268 assert(regionA && regionB && "could not find region of op"); 269 270 // Helper function that checks if region `r` is reachable from region 271 // `begin`. 272 std::function<bool(Region *, Region *)> isRegionReachable = 273 [&](Region *begin, Region *r) { 274 if (begin == r) 275 return true; 276 if (begin == nullptr) 277 return false; 278 // Compute index of region. 279 int64_t beginIndex = -1; 280 for (const auto &it : llvm::enumerate(branchOp->getRegions())) 281 if (&it.value() == begin) 282 beginIndex = it.index(); 283 assert(beginIndex != -1 && "could not find region in op"); 284 // Retrieve all successors of the region. 285 SmallVector<RegionSuccessor> successors; 286 branchOp.getSuccessorRegions(beginIndex, successors); 287 // Call function recursively on all successors. 288 for (RegionSuccessor successor : successors) 289 if (isRegionReachable(successor.getSuccessor(), r)) 290 return true; 291 return false; 292 }; 293 294 // `a` and `b` are in mutually exclusive regions if neither region is 295 // reachable from the other region. 296 return !isRegionReachable(regionA, regionB) && 297 !isRegionReachable(regionB, regionA); 298 } 299 300 // Could not find a common RegionBranchOpInterface among a's and b's 301 // ancestors. 302 return false; 303 } 304 305 //===----------------------------------------------------------------------===// 306 // RegionBranchTerminatorOpInterface 307 //===----------------------------------------------------------------------===// 308 309 /// Returns true if the given operation is either annotated with the 310 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`. 311 bool mlir::isRegionReturnLike(Operation *operation) { 312 return dyn_cast<RegionBranchTerminatorOpInterface>(operation) || 313 operation->hasTrait<OpTrait::ReturnLike>(); 314 } 315 316 /// Returns the mutable operands that are passed to the region with the given 317 /// `regionIndex`. If the operation does not implement the 318 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the 319 /// result will be `llvm::None`. In all other cases, the resulting 320 /// `OperandRange` represents all operands that are passed to the specified 321 /// successor region. If `regionIndex` is `llvm::None`, all operands that are 322 /// passed to the parent operation will be returned. 323 Optional<MutableOperandRange> 324 mlir::getMutableRegionBranchSuccessorOperands(Operation *operation, 325 Optional<unsigned> regionIndex) { 326 // Try to query a RegionBranchTerminatorOpInterface to determine 327 // all successor operands that will be passed to the successor 328 // input arguments. 329 if (auto regionTerminatorInterface = 330 dyn_cast<RegionBranchTerminatorOpInterface>(operation)) 331 return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex); 332 333 // TODO: The ReturnLike trait should imply a default implementation of the 334 // RegionBranchTerminatorOpInterface. This would make this code significantly 335 // easier. Furthermore, this may even make this function obsolete. 336 if (operation->hasTrait<OpTrait::ReturnLike>()) 337 return MutableOperandRange(operation); 338 return llvm::None; 339 } 340 341 /// Returns the read only operands that are passed to the region with the given 342 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more 343 /// information. 344 Optional<OperandRange> 345 mlir::getRegionBranchSuccessorOperands(Operation *operation, 346 Optional<unsigned> regionIndex) { 347 auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex); 348 return range ? Optional<OperandRange>(*range) : llvm::None; 349 } 350