1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ControlFlowInterfaces.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "llvm/ADT/SmallPtrSet.h" 12 13 using namespace mlir; 14 15 //===----------------------------------------------------------------------===// 16 // ControlFlowInterfaces 17 //===----------------------------------------------------------------------===// 18 19 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" 20 21 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands) 22 : producedOperandCount(0), forwardedOperands(forwardedOperands) {} 23 24 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount, 25 MutableOperandRange forwardedOperands) 26 : producedOperandCount(producedOperandCount), 27 forwardedOperands(std::move(forwardedOperands)) {} 28 29 //===----------------------------------------------------------------------===// 30 // BranchOpInterface 31 //===----------------------------------------------------------------------===// 32 33 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some 34 /// successor if 'operandIndex' is within the range of 'operands', or None if 35 /// `operandIndex` isn't a successor operand index. 36 Optional<BlockArgument> 37 detail::getBranchSuccessorArgument(const SuccessorOperands &operands, 38 unsigned operandIndex, Block *successor) { 39 OperandRange forwardedOperands = operands.getForwardedOperands(); 40 // Check that the operands are valid. 41 if (forwardedOperands.empty()) 42 return llvm::None; 43 44 // Check to ensure that this operand is within the range. 45 unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); 46 if (operandIndex < operandsStart || 47 operandIndex >= (operandsStart + forwardedOperands.size())) 48 return llvm::None; 49 50 // Index the successor. 51 unsigned argIndex = 52 operands.getProducedOperandCount() + operandIndex - operandsStart; 53 return successor->getArgument(argIndex); 54 } 55 56 /// Verify that the given operands match those of the given successor block. 57 LogicalResult 58 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, 59 const SuccessorOperands &operands) { 60 // Check the count. 61 unsigned operandCount = operands.size(); 62 Block *destBB = op->getSuccessor(succNo); 63 if (operandCount != destBB->getNumArguments()) 64 return op->emitError() << "branch has " << operandCount 65 << " operands for successor #" << succNo 66 << ", but target block has " 67 << destBB->getNumArguments(); 68 69 // Check the types. 70 for (unsigned i = operands.getProducedOperandCount(); i != operandCount; 71 ++i) { 72 if (!cast<BranchOpInterface>(op).areTypesCompatible( 73 operands[i].getType(), destBB->getArgument(i).getType())) 74 return op->emitError() << "type mismatch for bb argument #" << i 75 << " of successor #" << succNo; 76 } 77 return success(); 78 } 79 80 //===----------------------------------------------------------------------===// 81 // RegionBranchOpInterface 82 //===----------------------------------------------------------------------===// 83 84 /// Verify that types match along all region control flow edges originating from 85 /// `sourceNo` (region # if source is a region, llvm::None if source is parent 86 /// op). `getInputsTypesForRegion` is a function that returns the types of the 87 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if 88 /// the exact type match verification is not necessary (e.g., if the Op verifies 89 /// the match itself). 90 static LogicalResult 91 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo, 92 function_ref<Optional<TypeRange>(Optional<unsigned>)> 93 getInputsTypesForRegion) { 94 auto regionInterface = cast<RegionBranchOpInterface>(op); 95 96 SmallVector<RegionSuccessor, 2> successors; 97 unsigned numInputs; 98 if (sourceNo) { 99 Region &srcRegion = op->getRegion(sourceNo.getValue()); 100 numInputs = srcRegion.getNumArguments(); 101 } else { 102 numInputs = op->getNumOperands(); 103 } 104 SmallVector<Attribute, 2> operands(numInputs, nullptr); 105 regionInterface.getSuccessorRegions(sourceNo, operands, successors); 106 107 for (RegionSuccessor &succ : successors) { 108 Optional<unsigned> succRegionNo; 109 if (!succ.isParent()) 110 succRegionNo = succ.getSuccessor()->getRegionNumber(); 111 112 auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { 113 diag << "from "; 114 if (sourceNo) 115 diag << "Region #" << sourceNo.getValue(); 116 else 117 diag << "parent operands"; 118 119 diag << " to "; 120 if (succRegionNo) 121 diag << "Region #" << succRegionNo.getValue(); 122 else 123 diag << "parent results"; 124 return diag; 125 }; 126 127 Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo); 128 if (!sourceTypes.hasValue()) 129 continue; 130 131 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); 132 if (sourceTypes->size() != succInputsTypes.size()) { 133 InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); 134 return printEdgeName(diag) << ": source has " << sourceTypes->size() 135 << " operands, but target successor needs " 136 << succInputsTypes.size(); 137 } 138 139 for (const auto &typesIdx : 140 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { 141 Type sourceType = std::get<0>(typesIdx.value()); 142 Type inputType = std::get<1>(typesIdx.value()); 143 if (!regionInterface.areTypesCompatible(sourceType, inputType)) { 144 InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); 145 return printEdgeName(diag) 146 << ": source type #" << typesIdx.index() << " " << sourceType 147 << " should match input type #" << typesIdx.index() << " " 148 << inputType; 149 } 150 } 151 } 152 return success(); 153 } 154 155 /// Verify that types match along control flow edges described the given op. 156 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { 157 auto regionInterface = cast<RegionBranchOpInterface>(op); 158 159 auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange { 160 if (regionNo.hasValue()) { 161 return regionInterface.getSuccessorEntryOperands(regionNo.getValue()) 162 .getTypes(); 163 } 164 165 // If the successor of a parent op is the parent itself 166 // RegionBranchOpInterface does not have an API to query what the entry 167 // operands will be in that case. Vend out the result types of the op in 168 // that case so that type checking succeeds for this case. 169 return op->getResultTypes(); 170 }; 171 172 // Verify types along control flow edges originating from the parent. 173 if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent))) 174 return failure(); 175 176 // RegionBranchOpInterface should not be implemented by Ops that do not have 177 // attached regions. 178 assert(op->getNumRegions() != 0); 179 180 auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { 181 if (lhs.size() != rhs.size()) 182 return false; 183 for (auto types : llvm::zip(lhs, rhs)) { 184 if (!regionInterface.areTypesCompatible(std::get<0>(types), 185 std::get<1>(types))) { 186 return false; 187 } 188 } 189 return true; 190 }; 191 192 // Verify types along control flow edges originating from each region. 193 for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { 194 Region ®ion = op->getRegion(regionNo); 195 196 // Since there can be multiple `ReturnLike` terminators or others 197 // implementing the `RegionBranchTerminatorOpInterface`, all should have the 198 // same operand types when passing them to the same region. 199 200 Optional<OperandRange> regionReturnOperands; 201 for (Block &block : region) { 202 Operation *terminator = block.getTerminator(); 203 auto terminatorOperands = 204 getRegionBranchSuccessorOperands(terminator, regionNo); 205 if (!terminatorOperands) 206 continue; 207 208 if (!regionReturnOperands) { 209 regionReturnOperands = terminatorOperands; 210 continue; 211 } 212 213 // Found more than one ReturnLike terminator. Make sure the operand types 214 // match with the first one. 215 if (!areTypesCompatible(regionReturnOperands->getTypes(), 216 terminatorOperands->getTypes())) 217 return op->emitOpError("Region #") 218 << regionNo 219 << " operands mismatch between return-like terminators"; 220 } 221 222 auto inputTypesFromRegion = 223 [&](Optional<unsigned> regionNo) -> Optional<TypeRange> { 224 // If there is no return-like terminator, the op itself should verify 225 // type consistency. 226 if (!regionReturnOperands) 227 return llvm::None; 228 229 // All successors get the same set of operand types. 230 return TypeRange(regionReturnOperands->getTypes()); 231 }; 232 233 if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) 234 return failure(); 235 } 236 237 return success(); 238 } 239 240 /// Return `true` if region `r` is reachable from region `begin` according to 241 /// the RegionBranchOpInterface (by taking a branch). 242 static bool isRegionReachable(Region *begin, Region *r) { 243 assert(begin->getParentOp() == r->getParentOp() && 244 "expected that both regions belong to the same op"); 245 auto op = cast<RegionBranchOpInterface>(begin->getParentOp()); 246 SmallVector<bool> visited(op->getNumRegions(), false); 247 visited[begin->getRegionNumber()] = true; 248 249 // Retrieve all successors of the region and enqueue them in the worklist. 250 SmallVector<unsigned> worklist; 251 auto enqueueAllSuccessors = [&](unsigned index) { 252 SmallVector<RegionSuccessor> successors; 253 op.getSuccessorRegions(index, successors); 254 for (RegionSuccessor successor : successors) 255 if (!successor.isParent()) 256 worklist.push_back(successor.getSuccessor()->getRegionNumber()); 257 }; 258 enqueueAllSuccessors(begin->getRegionNumber()); 259 260 // Process all regions in the worklist via DFS. 261 while (!worklist.empty()) { 262 unsigned nextRegion = worklist.pop_back_val(); 263 if (nextRegion == r->getRegionNumber()) 264 return true; 265 if (visited[nextRegion]) 266 continue; 267 visited[nextRegion] = true; 268 enqueueAllSuccessors(nextRegion); 269 } 270 271 return false; 272 } 273 274 /// Return `true` if `a` and `b` are in mutually exclusive regions. 275 /// 276 /// 1. Find the first common of `a` and `b` (ancestor) that implements 277 /// RegionBranchOpInterface. 278 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are 279 /// contained. 280 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are 281 /// mutually exclusive if they are not reachable from each other as per 282 /// RegionBranchOpInterface::getSuccessorRegions. 283 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { 284 assert(a && "expected non-empty operation"); 285 assert(b && "expected non-empty operation"); 286 287 auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); 288 while (branchOp) { 289 // Check if b is inside branchOp. (We already know that a is.) 290 if (!branchOp->isProperAncestor(b)) { 291 // Check next enclosing RegionBranchOpInterface. 292 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>(); 293 continue; 294 } 295 296 // b is contained in branchOp. Retrieve the regions in which `a` and `b` 297 // are contained. 298 Region *regionA = nullptr, *regionB = nullptr; 299 for (Region &r : branchOp->getRegions()) { 300 if (r.findAncestorOpInRegion(*a)) { 301 assert(!regionA && "already found a region for a"); 302 regionA = &r; 303 } 304 if (r.findAncestorOpInRegion(*b)) { 305 assert(!regionB && "already found a region for b"); 306 regionB = &r; 307 } 308 } 309 assert(regionA && regionB && "could not find region of op"); 310 311 // `a` and `b` are in mutually exclusive regions if both regions are 312 // distinct and neither region is reachable from the other region. 313 return regionA != regionB && !isRegionReachable(regionA, regionB) && 314 !isRegionReachable(regionB, regionA); 315 } 316 317 // Could not find a common RegionBranchOpInterface among a's and b's 318 // ancestors. 319 return false; 320 } 321 322 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) { 323 Region *region = &getOperation()->getRegion(index); 324 return isRegionReachable(region, region); 325 } 326 327 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { 328 while (Region *region = op->getParentRegion()) { 329 op = region->getParentOp(); 330 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) 331 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) 332 return region; 333 } 334 return nullptr; 335 } 336 337 Region *mlir::getEnclosingRepetitiveRegion(Value value) { 338 Region *region = value.getParentRegion(); 339 while (region) { 340 Operation *op = region->getParentOp(); 341 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) 342 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) 343 return region; 344 region = op->getParentRegion(); 345 } 346 return nullptr; 347 } 348 349 //===----------------------------------------------------------------------===// 350 // RegionBranchTerminatorOpInterface 351 //===----------------------------------------------------------------------===// 352 353 /// Returns true if the given operation is either annotated with the 354 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`. 355 bool mlir::isRegionReturnLike(Operation *operation) { 356 return dyn_cast<RegionBranchTerminatorOpInterface>(operation) || 357 operation->hasTrait<OpTrait::ReturnLike>(); 358 } 359 360 /// Returns the mutable operands that are passed to the region with the given 361 /// `regionIndex`. If the operation does not implement the 362 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the 363 /// result will be `llvm::None`. In all other cases, the resulting 364 /// `OperandRange` represents all operands that are passed to the specified 365 /// successor region. If `regionIndex` is `llvm::None`, all operands that are 366 /// passed to the parent operation will be returned. 367 Optional<MutableOperandRange> 368 mlir::getMutableRegionBranchSuccessorOperands(Operation *operation, 369 Optional<unsigned> regionIndex) { 370 // Try to query a RegionBranchTerminatorOpInterface to determine 371 // all successor operands that will be passed to the successor 372 // input arguments. 373 if (auto regionTerminatorInterface = 374 dyn_cast<RegionBranchTerminatorOpInterface>(operation)) 375 return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex); 376 377 // TODO: The ReturnLike trait should imply a default implementation of the 378 // RegionBranchTerminatorOpInterface. This would make this code significantly 379 // easier. Furthermore, this may even make this function obsolete. 380 if (operation->hasTrait<OpTrait::ReturnLike>()) 381 return MutableOperandRange(operation); 382 return llvm::None; 383 } 384 385 /// Returns the read only operands that are passed to the region with the given 386 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more 387 /// information. 388 Optional<OperandRange> 389 mlir::getRegionBranchSuccessorOperands(Operation *operation, 390 Optional<unsigned> regionIndex) { 391 auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex); 392 return range ? Optional<OperandRange>(*range) : llvm::None; 393 } 394