1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ControlFlowInterfaces.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "llvm/ADT/SmallPtrSet.h" 12 13 using namespace mlir; 14 15 //===----------------------------------------------------------------------===// 16 // ControlFlowInterfaces 17 //===----------------------------------------------------------------------===// 18 19 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" 20 21 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands) 22 : producedOperandCount(0), forwardedOperands(forwardedOperands) {} 23 24 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount, 25 MutableOperandRange forwardedOperands) 26 : producedOperandCount(producedOperandCount), 27 forwardedOperands(std::move(forwardedOperands)) {} 28 29 //===----------------------------------------------------------------------===// 30 // BranchOpInterface 31 //===----------------------------------------------------------------------===// 32 33 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some 34 /// successor if 'operandIndex' is within the range of 'operands', or None if 35 /// `operandIndex` isn't a successor operand index. 36 Optional<BlockArgument> 37 detail::getBranchSuccessorArgument(const SuccessorOperands &operands, 38 unsigned operandIndex, Block *successor) { 39 OperandRange forwardedOperands = operands.getForwardedOperands(); 40 // Check that the operands are valid. 41 if (forwardedOperands.empty()) 42 return llvm::None; 43 44 // Check to ensure that this operand is within the range. 45 unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); 46 if (operandIndex < operandsStart || 47 operandIndex >= (operandsStart + forwardedOperands.size())) 48 return llvm::None; 49 50 // Index the successor. 51 unsigned argIndex = 52 operands.getProducedOperandCount() + operandIndex - operandsStart; 53 return successor->getArgument(argIndex); 54 } 55 56 /// Verify that the given operands match those of the given successor block. 57 LogicalResult 58 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, 59 const SuccessorOperands &operands) { 60 // Check the count. 61 unsigned operandCount = operands.size(); 62 Block *destBB = op->getSuccessor(succNo); 63 if (operandCount != destBB->getNumArguments()) 64 return op->emitError() << "branch has " << operandCount 65 << " operands for successor #" << succNo 66 << ", but target block has " 67 << destBB->getNumArguments(); 68 69 // Check the types. 70 for (unsigned i = operands.getProducedOperandCount(); i != operandCount; 71 ++i) { 72 if (!cast<BranchOpInterface>(op).areTypesCompatible( 73 operands[i].getType(), destBB->getArgument(i).getType())) 74 return op->emitError() << "type mismatch for bb argument #" << i 75 << " of successor #" << succNo; 76 } 77 return success(); 78 } 79 80 //===----------------------------------------------------------------------===// 81 // RegionBranchOpInterface 82 //===----------------------------------------------------------------------===// 83 84 /// Verify that types match along all region control flow edges originating from 85 /// `sourceNo` (region # if source is a region, llvm::None if source is parent 86 /// op). `getInputsTypesForRegion` is a function that returns the types of the 87 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if 88 /// the exact type match verification is not necessary (e.g., if the Op verifies 89 /// the match itself). 90 static LogicalResult 91 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo, 92 function_ref<Optional<TypeRange>(Optional<unsigned>)> 93 getInputsTypesForRegion) { 94 auto regionInterface = cast<RegionBranchOpInterface>(op); 95 96 SmallVector<RegionSuccessor, 2> successors; 97 unsigned numInputs; 98 if (sourceNo) { 99 Region &srcRegion = op->getRegion(sourceNo.getValue()); 100 numInputs = srcRegion.getNumArguments(); 101 } else { 102 numInputs = op->getNumOperands(); 103 } 104 SmallVector<Attribute, 2> operands(numInputs, nullptr); 105 regionInterface.getSuccessorRegions(sourceNo, operands, successors); 106 107 for (RegionSuccessor &succ : successors) { 108 Optional<unsigned> succRegionNo; 109 if (!succ.isParent()) 110 succRegionNo = succ.getSuccessor()->getRegionNumber(); 111 112 auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { 113 diag << "from "; 114 if (sourceNo) 115 diag << "Region #" << sourceNo.getValue(); 116 else 117 diag << "parent operands"; 118 119 diag << " to "; 120 if (succRegionNo) 121 diag << "Region #" << succRegionNo.getValue(); 122 else 123 diag << "parent results"; 124 return diag; 125 }; 126 127 Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo); 128 if (!sourceTypes.hasValue()) 129 continue; 130 131 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); 132 if (sourceTypes->size() != succInputsTypes.size()) { 133 InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); 134 return printEdgeName(diag) << ": source has " << sourceTypes->size() 135 << " operands, but target successor needs " 136 << succInputsTypes.size(); 137 } 138 139 for (const auto &typesIdx : 140 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { 141 Type sourceType = std::get<0>(typesIdx.value()); 142 Type inputType = std::get<1>(typesIdx.value()); 143 if (!regionInterface.areTypesCompatible(sourceType, inputType)) { 144 InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); 145 return printEdgeName(diag) 146 << ": source type #" << typesIdx.index() << " " << sourceType 147 << " should match input type #" << typesIdx.index() << " " 148 << inputType; 149 } 150 } 151 } 152 return success(); 153 } 154 155 /// Verify that types match along control flow edges described the given op. 156 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { 157 auto regionInterface = cast<RegionBranchOpInterface>(op); 158 159 auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange { 160 if (regionNo.hasValue()) { 161 return regionInterface.getSuccessorEntryOperands(regionNo.getValue()) 162 .getTypes(); 163 } 164 165 // If the successor of a parent op is the parent itself 166 // RegionBranchOpInterface does not have an API to query what the entry 167 // operands will be in that case. Vend out the result types of the op in 168 // that case so that type checking succeeds for this case. 169 return op->getResultTypes(); 170 }; 171 172 // Verify types along control flow edges originating from the parent. 173 if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent))) 174 return failure(); 175 176 // RegionBranchOpInterface should not be implemented by Ops that do not have 177 // attached regions. 178 assert(op->getNumRegions() != 0); 179 180 auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { 181 if (lhs.size() != rhs.size()) 182 return false; 183 for (auto types : llvm::zip(lhs, rhs)) { 184 if (!regionInterface.areTypesCompatible(std::get<0>(types), 185 std::get<1>(types))) { 186 return false; 187 } 188 } 189 return true; 190 }; 191 192 // Verify types along control flow edges originating from each region. 193 for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { 194 Region ®ion = op->getRegion(regionNo); 195 196 // Since there can be multiple `ReturnLike` terminators or others 197 // implementing the `RegionBranchTerminatorOpInterface`, all should have the 198 // same operand types when passing them to the same region. 199 200 Optional<OperandRange> regionReturnOperands; 201 for (Block &block : region) { 202 Operation *terminator = block.getTerminator(); 203 auto terminatorOperands = 204 getRegionBranchSuccessorOperands(terminator, regionNo); 205 if (!terminatorOperands) 206 continue; 207 208 if (!regionReturnOperands) { 209 regionReturnOperands = terminatorOperands; 210 continue; 211 } 212 213 // Found more than one ReturnLike terminator. Make sure the operand types 214 // match with the first one. 215 if (!areTypesCompatible(regionReturnOperands->getTypes(), 216 terminatorOperands->getTypes())) 217 return op->emitOpError("Region #") 218 << regionNo 219 << " operands mismatch between return-like terminators"; 220 } 221 222 auto inputTypesFromRegion = 223 [&](Optional<unsigned> regionNo) -> Optional<TypeRange> { 224 // If there is no return-like terminator, the op itself should verify 225 // type consistency. 226 if (!regionReturnOperands) 227 return llvm::None; 228 229 // All successors get the same set of operand types. 230 return TypeRange(regionReturnOperands->getTypes()); 231 }; 232 233 if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) 234 return failure(); 235 } 236 237 return success(); 238 } 239 240 /// Return `true` if `a` and `b` are in mutually exclusive regions. 241 /// 242 /// 1. Find the first common of `a` and `b` (ancestor) that implements 243 /// RegionBranchOpInterface. 244 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are 245 /// contained. 246 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are 247 /// mutually exclusive if they are not reachable from each other as per 248 /// RegionBranchOpInterface::getSuccessorRegions. 249 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { 250 assert(a && "expected non-empty operation"); 251 assert(b && "expected non-empty operation"); 252 253 auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); 254 while (branchOp) { 255 // Check if b is inside branchOp. (We already know that a is.) 256 if (!branchOp->isProperAncestor(b)) { 257 // Check next enclosing RegionBranchOpInterface. 258 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>(); 259 continue; 260 } 261 262 // b is contained in branchOp. Retrieve the regions in which `a` and `b` 263 // are contained. 264 Region *regionA = nullptr, *regionB = nullptr; 265 for (Region &r : branchOp->getRegions()) { 266 if (r.findAncestorOpInRegion(*a)) { 267 assert(!regionA && "already found a region for a"); 268 regionA = &r; 269 } 270 if (r.findAncestorOpInRegion(*b)) { 271 assert(!regionB && "already found a region for b"); 272 regionB = &r; 273 } 274 } 275 assert(regionA && regionB && "could not find region of op"); 276 277 // Helper function that checks if region `r` is reachable from region 278 // `begin`. 279 std::function<bool(Region *, Region *)> isRegionReachable = 280 [&](Region *begin, Region *r) { 281 if (begin == r) 282 return true; 283 if (begin == nullptr) 284 return false; 285 // Compute index of region. 286 int64_t beginIndex = -1; 287 for (const auto &it : llvm::enumerate(branchOp->getRegions())) 288 if (&it.value() == begin) 289 beginIndex = it.index(); 290 assert(beginIndex != -1 && "could not find region in op"); 291 // Retrieve all successors of the region. 292 SmallVector<RegionSuccessor> successors; 293 branchOp.getSuccessorRegions(beginIndex, successors); 294 // Call function recursively on all successors. 295 for (RegionSuccessor successor : successors) 296 if (isRegionReachable(successor.getSuccessor(), r)) 297 return true; 298 return false; 299 }; 300 301 // `a` and `b` are in mutually exclusive regions if neither region is 302 // reachable from the other region. 303 return !isRegionReachable(regionA, regionB) && 304 !isRegionReachable(regionB, regionA); 305 } 306 307 // Could not find a common RegionBranchOpInterface among a's and b's 308 // ancestors. 309 return false; 310 } 311 312 //===----------------------------------------------------------------------===// 313 // RegionBranchTerminatorOpInterface 314 //===----------------------------------------------------------------------===// 315 316 /// Returns true if the given operation is either annotated with the 317 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`. 318 bool mlir::isRegionReturnLike(Operation *operation) { 319 return dyn_cast<RegionBranchTerminatorOpInterface>(operation) || 320 operation->hasTrait<OpTrait::ReturnLike>(); 321 } 322 323 /// Returns the mutable operands that are passed to the region with the given 324 /// `regionIndex`. If the operation does not implement the 325 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the 326 /// result will be `llvm::None`. In all other cases, the resulting 327 /// `OperandRange` represents all operands that are passed to the specified 328 /// successor region. If `regionIndex` is `llvm::None`, all operands that are 329 /// passed to the parent operation will be returned. 330 Optional<MutableOperandRange> 331 mlir::getMutableRegionBranchSuccessorOperands(Operation *operation, 332 Optional<unsigned> regionIndex) { 333 // Try to query a RegionBranchTerminatorOpInterface to determine 334 // all successor operands that will be passed to the successor 335 // input arguments. 336 if (auto regionTerminatorInterface = 337 dyn_cast<RegionBranchTerminatorOpInterface>(operation)) 338 return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex); 339 340 // TODO: The ReturnLike trait should imply a default implementation of the 341 // RegionBranchTerminatorOpInterface. This would make this code significantly 342 // easier. Furthermore, this may even make this function obsolete. 343 if (operation->hasTrait<OpTrait::ReturnLike>()) 344 return MutableOperandRange(operation); 345 return llvm::None; 346 } 347 348 /// Returns the read only operands that are passed to the region with the given 349 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more 350 /// information. 351 Optional<OperandRange> 352 mlir::getRegionBranchSuccessorOperands(Operation *operation, 353 Optional<unsigned> regionIndex) { 354 auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex); 355 return range ? Optional<OperandRange>(*range) : llvm::None; 356 } 357