1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Interfaces/ControlFlowInterfaces.h" 12 #include "llvm/ADT/SmallPtrSet.h" 13 14 using namespace mlir; 15 16 //===----------------------------------------------------------------------===// 17 // ControlFlowInterfaces 18 //===----------------------------------------------------------------------===// 19 20 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" 21 22 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands) 23 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) { 24 } 25 26 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount, 27 MutableOperandRange forwardedOperands) 28 : producedOperandCount(producedOperandCount), 29 forwardedOperands(std::move(forwardedOperands)) {} 30 31 //===----------------------------------------------------------------------===// 32 // BranchOpInterface 33 //===----------------------------------------------------------------------===// 34 35 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some 36 /// successor if 'operandIndex' is within the range of 'operands', or None if 37 /// `operandIndex` isn't a successor operand index. 38 Optional<BlockArgument> 39 detail::getBranchSuccessorArgument(const SuccessorOperands &operands, 40 unsigned operandIndex, Block *successor) { 41 OperandRange forwardedOperands = operands.getForwardedOperands(); 42 // Check that the operands are valid. 43 if (forwardedOperands.empty()) 44 return llvm::None; 45 46 // Check to ensure that this operand is within the range. 47 unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); 48 if (operandIndex < operandsStart || 49 operandIndex >= (operandsStart + forwardedOperands.size())) 50 return llvm::None; 51 52 // Index the successor. 53 unsigned argIndex = 54 operands.getProducedOperandCount() + operandIndex - operandsStart; 55 return successor->getArgument(argIndex); 56 } 57 58 /// Verify that the given operands match those of the given successor block. 59 LogicalResult 60 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, 61 const SuccessorOperands &operands) { 62 // Check the count. 63 unsigned operandCount = operands.size(); 64 Block *destBB = op->getSuccessor(succNo); 65 if (operandCount != destBB->getNumArguments()) 66 return op->emitError() << "branch has " << operandCount 67 << " operands for successor #" << succNo 68 << ", but target block has " 69 << destBB->getNumArguments(); 70 71 // Check the types. 72 for (unsigned i = operands.getProducedOperandCount(); i != operandCount; 73 ++i) { 74 if (!cast<BranchOpInterface>(op).areTypesCompatible( 75 operands[i].getType(), destBB->getArgument(i).getType())) 76 return op->emitError() << "type mismatch for bb argument #" << i 77 << " of successor #" << succNo; 78 } 79 return success(); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // RegionBranchOpInterface 84 //===----------------------------------------------------------------------===// 85 86 /// Verify that types match along all region control flow edges originating from 87 /// `sourceNo` (region # if source is a region, llvm::None if source is parent 88 /// op). `getInputsTypesForRegion` is a function that returns the types of the 89 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if 90 /// the exact type match verification is not necessary (e.g., if the Op verifies 91 /// the match itself). 92 static LogicalResult 93 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo, 94 function_ref<Optional<TypeRange>(Optional<unsigned>)> 95 getInputsTypesForRegion) { 96 auto regionInterface = cast<RegionBranchOpInterface>(op); 97 98 SmallVector<RegionSuccessor, 2> successors; 99 regionInterface.getSuccessorRegions(sourceNo, successors); 100 101 for (RegionSuccessor &succ : successors) { 102 Optional<unsigned> succRegionNo; 103 if (!succ.isParent()) 104 succRegionNo = succ.getSuccessor()->getRegionNumber(); 105 106 auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { 107 diag << "from "; 108 if (sourceNo) 109 diag << "Region #" << sourceNo.getValue(); 110 else 111 diag << "parent operands"; 112 113 diag << " to "; 114 if (succRegionNo) 115 diag << "Region #" << succRegionNo.getValue(); 116 else 117 diag << "parent results"; 118 return diag; 119 }; 120 121 Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo); 122 if (!sourceTypes.hasValue()) 123 continue; 124 125 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); 126 if (sourceTypes->size() != succInputsTypes.size()) { 127 InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); 128 return printEdgeName(diag) << ": source has " << sourceTypes->size() 129 << " operands, but target successor needs " 130 << succInputsTypes.size(); 131 } 132 133 for (const auto &typesIdx : 134 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { 135 Type sourceType = std::get<0>(typesIdx.value()); 136 Type inputType = std::get<1>(typesIdx.value()); 137 if (!regionInterface.areTypesCompatible(sourceType, inputType)) { 138 InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); 139 return printEdgeName(diag) 140 << ": source type #" << typesIdx.index() << " " << sourceType 141 << " should match input type #" << typesIdx.index() << " " 142 << inputType; 143 } 144 } 145 } 146 return success(); 147 } 148 149 /// Verify that types match along control flow edges described the given op. 150 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { 151 auto regionInterface = cast<RegionBranchOpInterface>(op); 152 153 auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange { 154 if (regionNo.hasValue()) { 155 return regionInterface.getSuccessorEntryOperands(regionNo.getValue()) 156 .getTypes(); 157 } 158 159 // If the successor of a parent op is the parent itself 160 // RegionBranchOpInterface does not have an API to query what the entry 161 // operands will be in that case. Vend out the result types of the op in 162 // that case so that type checking succeeds for this case. 163 return op->getResultTypes(); 164 }; 165 166 // Verify types along control flow edges originating from the parent. 167 if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent))) 168 return failure(); 169 170 // RegionBranchOpInterface should not be implemented by Ops that do not have 171 // attached regions. 172 assert(op->getNumRegions() != 0); 173 174 auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { 175 if (lhs.size() != rhs.size()) 176 return false; 177 for (auto types : llvm::zip(lhs, rhs)) { 178 if (!regionInterface.areTypesCompatible(std::get<0>(types), 179 std::get<1>(types))) { 180 return false; 181 } 182 } 183 return true; 184 }; 185 186 // Verify types along control flow edges originating from each region. 187 for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { 188 Region ®ion = op->getRegion(regionNo); 189 190 // Since there can be multiple `ReturnLike` terminators or others 191 // implementing the `RegionBranchTerminatorOpInterface`, all should have the 192 // same operand types when passing them to the same region. 193 194 Optional<OperandRange> regionReturnOperands; 195 for (Block &block : region) { 196 Operation *terminator = block.getTerminator(); 197 auto terminatorOperands = 198 getRegionBranchSuccessorOperands(terminator, regionNo); 199 if (!terminatorOperands) 200 continue; 201 202 if (!regionReturnOperands) { 203 regionReturnOperands = terminatorOperands; 204 continue; 205 } 206 207 // Found more than one ReturnLike terminator. Make sure the operand types 208 // match with the first one. 209 if (!areTypesCompatible(regionReturnOperands->getTypes(), 210 terminatorOperands->getTypes())) 211 return op->emitOpError("Region #") 212 << regionNo 213 << " operands mismatch between return-like terminators"; 214 } 215 216 auto inputTypesFromRegion = 217 [&](Optional<unsigned> regionNo) -> Optional<TypeRange> { 218 // If there is no return-like terminator, the op itself should verify 219 // type consistency. 220 if (!regionReturnOperands) 221 return llvm::None; 222 223 // All successors get the same set of operand types. 224 return TypeRange(regionReturnOperands->getTypes()); 225 }; 226 227 if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) 228 return failure(); 229 } 230 231 return success(); 232 } 233 234 /// Return `true` if region `r` is reachable from region `begin` according to 235 /// the RegionBranchOpInterface (by taking a branch). 236 static bool isRegionReachable(Region *begin, Region *r) { 237 assert(begin->getParentOp() == r->getParentOp() && 238 "expected that both regions belong to the same op"); 239 auto op = cast<RegionBranchOpInterface>(begin->getParentOp()); 240 SmallVector<bool> visited(op->getNumRegions(), false); 241 visited[begin->getRegionNumber()] = true; 242 243 // Retrieve all successors of the region and enqueue them in the worklist. 244 SmallVector<unsigned> worklist; 245 auto enqueueAllSuccessors = [&](unsigned index) { 246 SmallVector<RegionSuccessor> successors; 247 op.getSuccessorRegions(index, successors); 248 for (RegionSuccessor successor : successors) 249 if (!successor.isParent()) 250 worklist.push_back(successor.getSuccessor()->getRegionNumber()); 251 }; 252 enqueueAllSuccessors(begin->getRegionNumber()); 253 254 // Process all regions in the worklist via DFS. 255 while (!worklist.empty()) { 256 unsigned nextRegion = worklist.pop_back_val(); 257 if (nextRegion == r->getRegionNumber()) 258 return true; 259 if (visited[nextRegion]) 260 continue; 261 visited[nextRegion] = true; 262 enqueueAllSuccessors(nextRegion); 263 } 264 265 return false; 266 } 267 268 /// Return `true` if `a` and `b` are in mutually exclusive regions. 269 /// 270 /// 1. Find the first common of `a` and `b` (ancestor) that implements 271 /// RegionBranchOpInterface. 272 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are 273 /// contained. 274 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are 275 /// mutually exclusive if they are not reachable from each other as per 276 /// RegionBranchOpInterface::getSuccessorRegions. 277 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { 278 assert(a && "expected non-empty operation"); 279 assert(b && "expected non-empty operation"); 280 281 auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); 282 while (branchOp) { 283 // Check if b is inside branchOp. (We already know that a is.) 284 if (!branchOp->isProperAncestor(b)) { 285 // Check next enclosing RegionBranchOpInterface. 286 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>(); 287 continue; 288 } 289 290 // b is contained in branchOp. Retrieve the regions in which `a` and `b` 291 // are contained. 292 Region *regionA = nullptr, *regionB = nullptr; 293 for (Region &r : branchOp->getRegions()) { 294 if (r.findAncestorOpInRegion(*a)) { 295 assert(!regionA && "already found a region for a"); 296 regionA = &r; 297 } 298 if (r.findAncestorOpInRegion(*b)) { 299 assert(!regionB && "already found a region for b"); 300 regionB = &r; 301 } 302 } 303 assert(regionA && regionB && "could not find region of op"); 304 305 // `a` and `b` are in mutually exclusive regions if both regions are 306 // distinct and neither region is reachable from the other region. 307 return regionA != regionB && !isRegionReachable(regionA, regionB) && 308 !isRegionReachable(regionB, regionA); 309 } 310 311 // Could not find a common RegionBranchOpInterface among a's and b's 312 // ancestors. 313 return false; 314 } 315 316 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) { 317 Region *region = &getOperation()->getRegion(index); 318 return isRegionReachable(region, region); 319 } 320 321 void RegionBranchOpInterface::getSuccessorRegions( 322 Optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) { 323 unsigned numInputs = 0; 324 if (index) { 325 // If the predecessor is a region, get the number of operands from an 326 // exiting terminator in the region. 327 for (Block &block : getOperation()->getRegion(*index)) { 328 Operation *terminator = block.getTerminator(); 329 if (getRegionBranchSuccessorOperands(terminator, *index)) { 330 numInputs = terminator->getNumOperands(); 331 break; 332 } 333 } 334 } else { 335 // Otherwise, use the number of parent operation operands. 336 numInputs = getOperation()->getNumOperands(); 337 } 338 SmallVector<Attribute, 2> operands(numInputs, nullptr); 339 getSuccessorRegions(index, operands, regions); 340 } 341 342 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { 343 while (Region *region = op->getParentRegion()) { 344 op = region->getParentOp(); 345 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) 346 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) 347 return region; 348 } 349 return nullptr; 350 } 351 352 Region *mlir::getEnclosingRepetitiveRegion(Value value) { 353 Region *region = value.getParentRegion(); 354 while (region) { 355 Operation *op = region->getParentOp(); 356 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) 357 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) 358 return region; 359 region = op->getParentRegion(); 360 } 361 return nullptr; 362 } 363 364 //===----------------------------------------------------------------------===// 365 // RegionBranchTerminatorOpInterface 366 //===----------------------------------------------------------------------===// 367 368 /// Returns true if the given operation is either annotated with the 369 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`. 370 bool mlir::isRegionReturnLike(Operation *operation) { 371 return dyn_cast<RegionBranchTerminatorOpInterface>(operation) || 372 operation->hasTrait<OpTrait::ReturnLike>(); 373 } 374 375 /// Returns the mutable operands that are passed to the region with the given 376 /// `regionIndex`. If the operation does not implement the 377 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the 378 /// result will be `llvm::None`. In all other cases, the resulting 379 /// `OperandRange` represents all operands that are passed to the specified 380 /// successor region. If `regionIndex` is `llvm::None`, all operands that are 381 /// passed to the parent operation will be returned. 382 Optional<MutableOperandRange> 383 mlir::getMutableRegionBranchSuccessorOperands(Operation *operation, 384 Optional<unsigned> regionIndex) { 385 // Try to query a RegionBranchTerminatorOpInterface to determine 386 // all successor operands that will be passed to the successor 387 // input arguments. 388 if (auto regionTerminatorInterface = 389 dyn_cast<RegionBranchTerminatorOpInterface>(operation)) 390 return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex); 391 392 // TODO: The ReturnLike trait should imply a default implementation of the 393 // RegionBranchTerminatorOpInterface. This would make this code significantly 394 // easier. Furthermore, this may even make this function obsolete. 395 if (operation->hasTrait<OpTrait::ReturnLike>()) 396 return MutableOperandRange(operation); 397 return llvm::None; 398 } 399 400 /// Returns the read only operands that are passed to the region with the given 401 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more 402 /// information. 403 Optional<OperandRange> 404 mlir::getRegionBranchSuccessorOperands(Operation *operation, 405 Optional<unsigned> regionIndex) { 406 auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex); 407 return range ? Optional<OperandRange>(*range) : llvm::None; 408 } 409