1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Interfaces/ControlFlowInterfaces.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "llvm/ADT/SmallPtrSet.h" 14 15 using namespace mlir; 16 17 //===----------------------------------------------------------------------===// 18 // ControlFlowInterfaces 19 //===----------------------------------------------------------------------===// 20 21 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" 22 23 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands) 24 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) { 25 } 26 27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount, 28 MutableOperandRange forwardedOperands) 29 : producedOperandCount(producedOperandCount), 30 forwardedOperands(std::move(forwardedOperands)) {} 31 32 //===----------------------------------------------------------------------===// 33 // BranchOpInterface 34 //===----------------------------------------------------------------------===// 35 36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some 37 /// successor if 'operandIndex' is within the range of 'operands', or None if 38 /// `operandIndex` isn't a successor operand index. 39 Optional<BlockArgument> 40 detail::getBranchSuccessorArgument(const SuccessorOperands &operands, 41 unsigned operandIndex, Block *successor) { 42 OperandRange forwardedOperands = operands.getForwardedOperands(); 43 // Check that the operands are valid. 44 if (forwardedOperands.empty()) 45 return llvm::None; 46 47 // Check to ensure that this operand is within the range. 48 unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); 49 if (operandIndex < operandsStart || 50 operandIndex >= (operandsStart + forwardedOperands.size())) 51 return llvm::None; 52 53 // Index the successor. 54 unsigned argIndex = 55 operands.getProducedOperandCount() + operandIndex - operandsStart; 56 return successor->getArgument(argIndex); 57 } 58 59 /// Verify that the given operands match those of the given successor block. 60 LogicalResult 61 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, 62 const SuccessorOperands &operands) { 63 // Check the count. 64 unsigned operandCount = operands.size(); 65 Block *destBB = op->getSuccessor(succNo); 66 if (operandCount != destBB->getNumArguments()) 67 return op->emitError() << "branch has " << operandCount 68 << " operands for successor #" << succNo 69 << ", but target block has " 70 << destBB->getNumArguments(); 71 72 // Check the types. 73 for (unsigned i = operands.getProducedOperandCount(); i != operandCount; 74 ++i) { 75 if (!cast<BranchOpInterface>(op).areTypesCompatible( 76 operands[i].getType(), destBB->getArgument(i).getType())) 77 return op->emitError() << "type mismatch for bb argument #" << i 78 << " of successor #" << succNo; 79 } 80 return success(); 81 } 82 83 //===----------------------------------------------------------------------===// 84 // RegionBranchOpInterface 85 //===----------------------------------------------------------------------===// 86 87 /// Verify that types match along all region control flow edges originating from 88 /// `sourceNo` (region # if source is a region, llvm::None if source is parent 89 /// op). `getInputsTypesForRegion` is a function that returns the types of the 90 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if 91 /// the exact type match verification is not necessary (e.g., if the Op verifies 92 /// the match itself). 93 static LogicalResult 94 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo, 95 function_ref<Optional<TypeRange>(Optional<unsigned>)> 96 getInputsTypesForRegion) { 97 auto regionInterface = cast<RegionBranchOpInterface>(op); 98 99 SmallVector<RegionSuccessor, 2> successors; 100 unsigned numInputs; 101 if (sourceNo) { 102 Region &srcRegion = op->getRegion(sourceNo.getValue()); 103 numInputs = srcRegion.getNumArguments(); 104 } else { 105 numInputs = op->getNumOperands(); 106 } 107 SmallVector<Attribute, 2> operands(numInputs, nullptr); 108 regionInterface.getSuccessorRegions(sourceNo, operands, successors); 109 110 for (RegionSuccessor &succ : successors) { 111 Optional<unsigned> succRegionNo; 112 if (!succ.isParent()) 113 succRegionNo = succ.getSuccessor()->getRegionNumber(); 114 115 auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { 116 diag << "from "; 117 if (sourceNo) 118 diag << "Region #" << sourceNo.getValue(); 119 else 120 diag << "parent operands"; 121 122 diag << " to "; 123 if (succRegionNo) 124 diag << "Region #" << succRegionNo.getValue(); 125 else 126 diag << "parent results"; 127 return diag; 128 }; 129 130 Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo); 131 if (!sourceTypes.hasValue()) 132 continue; 133 134 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); 135 if (sourceTypes->size() != succInputsTypes.size()) { 136 InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); 137 return printEdgeName(diag) << ": source has " << sourceTypes->size() 138 << " operands, but target successor needs " 139 << succInputsTypes.size(); 140 } 141 142 for (const auto &typesIdx : 143 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { 144 Type sourceType = std::get<0>(typesIdx.value()); 145 Type inputType = std::get<1>(typesIdx.value()); 146 if (!regionInterface.areTypesCompatible(sourceType, inputType)) { 147 InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); 148 return printEdgeName(diag) 149 << ": source type #" << typesIdx.index() << " " << sourceType 150 << " should match input type #" << typesIdx.index() << " " 151 << inputType; 152 } 153 } 154 } 155 return success(); 156 } 157 158 /// Verify that types match along control flow edges described the given op. 159 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { 160 auto regionInterface = cast<RegionBranchOpInterface>(op); 161 162 auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange { 163 if (regionNo.hasValue()) { 164 return regionInterface.getSuccessorEntryOperands(regionNo.getValue()) 165 .getTypes(); 166 } 167 168 // If the successor of a parent op is the parent itself 169 // RegionBranchOpInterface does not have an API to query what the entry 170 // operands will be in that case. Vend out the result types of the op in 171 // that case so that type checking succeeds for this case. 172 return op->getResultTypes(); 173 }; 174 175 // Verify types along control flow edges originating from the parent. 176 if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent))) 177 return failure(); 178 179 // RegionBranchOpInterface should not be implemented by Ops that do not have 180 // attached regions. 181 assert(op->getNumRegions() != 0); 182 183 auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { 184 if (lhs.size() != rhs.size()) 185 return false; 186 for (auto types : llvm::zip(lhs, rhs)) { 187 if (!regionInterface.areTypesCompatible(std::get<0>(types), 188 std::get<1>(types))) { 189 return false; 190 } 191 } 192 return true; 193 }; 194 195 // Verify types along control flow edges originating from each region. 196 for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { 197 Region ®ion = op->getRegion(regionNo); 198 199 // Since there can be multiple `ReturnLike` terminators or others 200 // implementing the `RegionBranchTerminatorOpInterface`, all should have the 201 // same operand types when passing them to the same region. 202 203 Optional<OperandRange> regionReturnOperands; 204 for (Block &block : region) { 205 Operation *terminator = block.getTerminator(); 206 auto terminatorOperands = 207 getRegionBranchSuccessorOperands(terminator, regionNo); 208 if (!terminatorOperands) 209 continue; 210 211 if (!regionReturnOperands) { 212 regionReturnOperands = terminatorOperands; 213 continue; 214 } 215 216 // Found more than one ReturnLike terminator. Make sure the operand types 217 // match with the first one. 218 if (!areTypesCompatible(regionReturnOperands->getTypes(), 219 terminatorOperands->getTypes())) 220 return op->emitOpError("Region #") 221 << regionNo 222 << " operands mismatch between return-like terminators"; 223 } 224 225 auto inputTypesFromRegion = 226 [&](Optional<unsigned> regionNo) -> Optional<TypeRange> { 227 // If there is no return-like terminator, the op itself should verify 228 // type consistency. 229 if (!regionReturnOperands) 230 return llvm::None; 231 232 // All successors get the same set of operand types. 233 return TypeRange(regionReturnOperands->getTypes()); 234 }; 235 236 if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) 237 return failure(); 238 } 239 240 return success(); 241 } 242 243 /// Return `true` if region `r` is reachable from region `begin` according to 244 /// the RegionBranchOpInterface (by taking a branch). 245 static bool isRegionReachable(Region *begin, Region *r) { 246 assert(begin->getParentOp() == r->getParentOp() && 247 "expected that both regions belong to the same op"); 248 auto op = cast<RegionBranchOpInterface>(begin->getParentOp()); 249 SmallVector<bool> visited(op->getNumRegions(), false); 250 visited[begin->getRegionNumber()] = true; 251 252 // Retrieve all successors of the region and enqueue them in the worklist. 253 SmallVector<unsigned> worklist; 254 auto enqueueAllSuccessors = [&](unsigned index) { 255 SmallVector<RegionSuccessor> successors; 256 op.getSuccessorRegions(index, successors); 257 for (RegionSuccessor successor : successors) 258 if (!successor.isParent()) 259 worklist.push_back(successor.getSuccessor()->getRegionNumber()); 260 }; 261 enqueueAllSuccessors(begin->getRegionNumber()); 262 263 // Process all regions in the worklist via DFS. 264 while (!worklist.empty()) { 265 unsigned nextRegion = worklist.pop_back_val(); 266 if (nextRegion == r->getRegionNumber()) 267 return true; 268 if (visited[nextRegion]) 269 continue; 270 visited[nextRegion] = true; 271 enqueueAllSuccessors(nextRegion); 272 } 273 274 return false; 275 } 276 277 /// Return `true` if `a` and `b` are in mutually exclusive regions. 278 /// 279 /// 1. Find the first common of `a` and `b` (ancestor) that implements 280 /// RegionBranchOpInterface. 281 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are 282 /// contained. 283 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are 284 /// mutually exclusive if they are not reachable from each other as per 285 /// RegionBranchOpInterface::getSuccessorRegions. 286 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { 287 assert(a && "expected non-empty operation"); 288 assert(b && "expected non-empty operation"); 289 290 auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); 291 while (branchOp) { 292 // Check if b is inside branchOp. (We already know that a is.) 293 if (!branchOp->isProperAncestor(b)) { 294 // Check next enclosing RegionBranchOpInterface. 295 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>(); 296 continue; 297 } 298 299 // b is contained in branchOp. Retrieve the regions in which `a` and `b` 300 // are contained. 301 Region *regionA = nullptr, *regionB = nullptr; 302 for (Region &r : branchOp->getRegions()) { 303 if (r.findAncestorOpInRegion(*a)) { 304 assert(!regionA && "already found a region for a"); 305 regionA = &r; 306 } 307 if (r.findAncestorOpInRegion(*b)) { 308 assert(!regionB && "already found a region for b"); 309 regionB = &r; 310 } 311 } 312 assert(regionA && regionB && "could not find region of op"); 313 314 // `a` and `b` are in mutually exclusive regions if both regions are 315 // distinct and neither region is reachable from the other region. 316 return regionA != regionB && !isRegionReachable(regionA, regionB) && 317 !isRegionReachable(regionB, regionA); 318 } 319 320 // Could not find a common RegionBranchOpInterface among a's and b's 321 // ancestors. 322 return false; 323 } 324 325 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) { 326 Region *region = &getOperation()->getRegion(index); 327 return isRegionReachable(region, region); 328 } 329 330 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { 331 while (Region *region = op->getParentRegion()) { 332 op = region->getParentOp(); 333 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) 334 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) 335 return region; 336 } 337 return nullptr; 338 } 339 340 Region *mlir::getEnclosingRepetitiveRegion(Value value) { 341 Region *region = value.getParentRegion(); 342 while (region) { 343 Operation *op = region->getParentOp(); 344 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) 345 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) 346 return region; 347 region = op->getParentRegion(); 348 } 349 return nullptr; 350 } 351 352 //===----------------------------------------------------------------------===// 353 // RegionBranchTerminatorOpInterface 354 //===----------------------------------------------------------------------===// 355 356 /// Returns true if the given operation is either annotated with the 357 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`. 358 bool mlir::isRegionReturnLike(Operation *operation) { 359 return dyn_cast<RegionBranchTerminatorOpInterface>(operation) || 360 operation->hasTrait<OpTrait::ReturnLike>(); 361 } 362 363 /// Returns the mutable operands that are passed to the region with the given 364 /// `regionIndex`. If the operation does not implement the 365 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the 366 /// result will be `llvm::None`. In all other cases, the resulting 367 /// `OperandRange` represents all operands that are passed to the specified 368 /// successor region. If `regionIndex` is `llvm::None`, all operands that are 369 /// passed to the parent operation will be returned. 370 Optional<MutableOperandRange> 371 mlir::getMutableRegionBranchSuccessorOperands(Operation *operation, 372 Optional<unsigned> regionIndex) { 373 // Try to query a RegionBranchTerminatorOpInterface to determine 374 // all successor operands that will be passed to the successor 375 // input arguments. 376 if (auto regionTerminatorInterface = 377 dyn_cast<RegionBranchTerminatorOpInterface>(operation)) 378 return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex); 379 380 // TODO: The ReturnLike trait should imply a default implementation of the 381 // RegionBranchTerminatorOpInterface. This would make this code significantly 382 // easier. Furthermore, this may even make this function obsolete. 383 if (operation->hasTrait<OpTrait::ReturnLike>()) 384 return MutableOperandRange(operation); 385 return llvm::None; 386 } 387 388 /// Returns the read only operands that are passed to the region with the given 389 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more 390 /// information. 391 Optional<OperandRange> 392 mlir::getRegionBranchSuccessorOperands(Operation *operation, 393 Optional<unsigned> regionIndex) { 394 auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex); 395 return range ? Optional<OperandRange>(*range) : llvm::None; 396 } 397