1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/Interfaces/ControlFlowInterfaces.h" 13 #include "llvm/ADT/SmallPtrSet.h" 14 15 using namespace mlir; 16 17 //===----------------------------------------------------------------------===// 18 // ControlFlowInterfaces 19 //===----------------------------------------------------------------------===// 20 21 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" 22 23 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands) 24 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) { 25 } 26 27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount, 28 MutableOperandRange forwardedOperands) 29 : producedOperandCount(producedOperandCount), 30 forwardedOperands(std::move(forwardedOperands)) {} 31 32 //===----------------------------------------------------------------------===// 33 // BranchOpInterface 34 //===----------------------------------------------------------------------===// 35 36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some 37 /// successor if 'operandIndex' is within the range of 'operands', or None if 38 /// `operandIndex` isn't a successor operand index. 39 Optional<BlockArgument> 40 detail::getBranchSuccessorArgument(const SuccessorOperands &operands, 41 unsigned operandIndex, Block *successor) { 42 OperandRange forwardedOperands = operands.getForwardedOperands(); 43 // Check that the operands are valid. 44 if (forwardedOperands.empty()) 45 return llvm::None; 46 47 // Check to ensure that this operand is within the range. 48 unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); 49 if (operandIndex < operandsStart || 50 operandIndex >= (operandsStart + forwardedOperands.size())) 51 return llvm::None; 52 53 // Index the successor. 54 unsigned argIndex = 55 operands.getProducedOperandCount() + operandIndex - operandsStart; 56 return successor->getArgument(argIndex); 57 } 58 59 /// Verify that the given operands match those of the given successor block. 60 LogicalResult 61 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, 62 const SuccessorOperands &operands) { 63 // Check the count. 64 unsigned operandCount = operands.size(); 65 Block *destBB = op->getSuccessor(succNo); 66 if (operandCount != destBB->getNumArguments()) 67 return op->emitError() << "branch has " << operandCount 68 << " operands for successor #" << succNo 69 << ", but target block has " 70 << destBB->getNumArguments(); 71 72 // Check the types. 73 for (unsigned i = operands.getProducedOperandCount(); i != operandCount; 74 ++i) { 75 if (!cast<BranchOpInterface>(op).areTypesCompatible( 76 operands[i].getType(), destBB->getArgument(i).getType())) 77 return op->emitError() << "type mismatch for bb argument #" << i 78 << " of successor #" << succNo; 79 } 80 return success(); 81 } 82 83 //===----------------------------------------------------------------------===// 84 // RegionBranchOpInterface 85 //===----------------------------------------------------------------------===// 86 87 /// Verify that types match along all region control flow edges originating from 88 /// `sourceNo` (region # if source is a region, llvm::None if source is parent 89 /// op). `getInputsTypesForRegion` is a function that returns the types of the 90 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if 91 /// the exact type match verification is not necessary (e.g., if the Op verifies 92 /// the match itself). 93 static LogicalResult 94 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo, 95 function_ref<Optional<TypeRange>(Optional<unsigned>)> 96 getInputsTypesForRegion) { 97 auto regionInterface = cast<RegionBranchOpInterface>(op); 98 99 SmallVector<RegionSuccessor, 2> successors; 100 regionInterface.getSuccessorRegions(sourceNo, successors); 101 102 for (RegionSuccessor &succ : successors) { 103 Optional<unsigned> succRegionNo; 104 if (!succ.isParent()) 105 succRegionNo = succ.getSuccessor()->getRegionNumber(); 106 107 auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { 108 diag << "from "; 109 if (sourceNo) 110 diag << "Region #" << sourceNo.value(); 111 else 112 diag << "parent operands"; 113 114 diag << " to "; 115 if (succRegionNo) 116 diag << "Region #" << succRegionNo.value(); 117 else 118 diag << "parent results"; 119 return diag; 120 }; 121 122 Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo); 123 if (!sourceTypes.has_value()) 124 continue; 125 126 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); 127 if (sourceTypes->size() != succInputsTypes.size()) { 128 InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); 129 return printEdgeName(diag) << ": source has " << sourceTypes->size() 130 << " operands, but target successor needs " 131 << succInputsTypes.size(); 132 } 133 134 for (const auto &typesIdx : 135 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { 136 Type sourceType = std::get<0>(typesIdx.value()); 137 Type inputType = std::get<1>(typesIdx.value()); 138 if (!regionInterface.areTypesCompatible(sourceType, inputType)) { 139 InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); 140 return printEdgeName(diag) 141 << ": source type #" << typesIdx.index() << " " << sourceType 142 << " should match input type #" << typesIdx.index() << " " 143 << inputType; 144 } 145 } 146 } 147 return success(); 148 } 149 150 /// Verify that types match along control flow edges described the given op. 151 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { 152 auto regionInterface = cast<RegionBranchOpInterface>(op); 153 154 auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange { 155 return regionInterface.getSuccessorEntryOperands(regionNo).getTypes(); 156 }; 157 158 // Verify types along control flow edges originating from the parent. 159 if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent))) 160 return failure(); 161 162 // RegionBranchOpInterface should not be implemented by Ops that do not have 163 // attached regions. 164 assert(op->getNumRegions() != 0); 165 166 auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { 167 if (lhs.size() != rhs.size()) 168 return false; 169 for (auto types : llvm::zip(lhs, rhs)) { 170 if (!regionInterface.areTypesCompatible(std::get<0>(types), 171 std::get<1>(types))) { 172 return false; 173 } 174 } 175 return true; 176 }; 177 178 // Verify types along control flow edges originating from each region. 179 for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { 180 Region ®ion = op->getRegion(regionNo); 181 182 // Since there can be multiple `ReturnLike` terminators or others 183 // implementing the `RegionBranchTerminatorOpInterface`, all should have the 184 // same operand types when passing them to the same region. 185 186 Optional<OperandRange> regionReturnOperands; 187 for (Block &block : region) { 188 Operation *terminator = block.getTerminator(); 189 auto terminatorOperands = 190 getRegionBranchSuccessorOperands(terminator, regionNo); 191 if (!terminatorOperands) 192 continue; 193 194 if (!regionReturnOperands) { 195 regionReturnOperands = terminatorOperands; 196 continue; 197 } 198 199 // Found more than one ReturnLike terminator. Make sure the operand types 200 // match with the first one. 201 if (!areTypesCompatible(regionReturnOperands->getTypes(), 202 terminatorOperands->getTypes())) 203 return op->emitOpError("Region #") 204 << regionNo 205 << " operands mismatch between return-like terminators"; 206 } 207 208 auto inputTypesFromRegion = 209 [&](Optional<unsigned> regionNo) -> Optional<TypeRange> { 210 // If there is no return-like terminator, the op itself should verify 211 // type consistency. 212 if (!regionReturnOperands) 213 return llvm::None; 214 215 // All successors get the same set of operand types. 216 return TypeRange(regionReturnOperands->getTypes()); 217 }; 218 219 if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) 220 return failure(); 221 } 222 223 return success(); 224 } 225 226 /// Return `true` if region `r` is reachable from region `begin` according to 227 /// the RegionBranchOpInterface (by taking a branch). 228 static bool isRegionReachable(Region *begin, Region *r) { 229 assert(begin->getParentOp() == r->getParentOp() && 230 "expected that both regions belong to the same op"); 231 auto op = cast<RegionBranchOpInterface>(begin->getParentOp()); 232 SmallVector<bool> visited(op->getNumRegions(), false); 233 visited[begin->getRegionNumber()] = true; 234 235 // Retrieve all successors of the region and enqueue them in the worklist. 236 SmallVector<unsigned> worklist; 237 auto enqueueAllSuccessors = [&](unsigned index) { 238 SmallVector<RegionSuccessor> successors; 239 op.getSuccessorRegions(index, successors); 240 for (RegionSuccessor successor : successors) 241 if (!successor.isParent()) 242 worklist.push_back(successor.getSuccessor()->getRegionNumber()); 243 }; 244 enqueueAllSuccessors(begin->getRegionNumber()); 245 246 // Process all regions in the worklist via DFS. 247 while (!worklist.empty()) { 248 unsigned nextRegion = worklist.pop_back_val(); 249 if (nextRegion == r->getRegionNumber()) 250 return true; 251 if (visited[nextRegion]) 252 continue; 253 visited[nextRegion] = true; 254 enqueueAllSuccessors(nextRegion); 255 } 256 257 return false; 258 } 259 260 /// Return `true` if `a` and `b` are in mutually exclusive regions. 261 /// 262 /// 1. Find the first common of `a` and `b` (ancestor) that implements 263 /// RegionBranchOpInterface. 264 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are 265 /// contained. 266 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are 267 /// mutually exclusive if they are not reachable from each other as per 268 /// RegionBranchOpInterface::getSuccessorRegions. 269 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { 270 assert(a && "expected non-empty operation"); 271 assert(b && "expected non-empty operation"); 272 273 auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); 274 while (branchOp) { 275 // Check if b is inside branchOp. (We already know that a is.) 276 if (!branchOp->isProperAncestor(b)) { 277 // Check next enclosing RegionBranchOpInterface. 278 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>(); 279 continue; 280 } 281 282 // b is contained in branchOp. Retrieve the regions in which `a` and `b` 283 // are contained. 284 Region *regionA = nullptr, *regionB = nullptr; 285 for (Region &r : branchOp->getRegions()) { 286 if (r.findAncestorOpInRegion(*a)) { 287 assert(!regionA && "already found a region for a"); 288 regionA = &r; 289 } 290 if (r.findAncestorOpInRegion(*b)) { 291 assert(!regionB && "already found a region for b"); 292 regionB = &r; 293 } 294 } 295 assert(regionA && regionB && "could not find region of op"); 296 297 // `a` and `b` are in mutually exclusive regions if both regions are 298 // distinct and neither region is reachable from the other region. 299 return regionA != regionB && !isRegionReachable(regionA, regionB) && 300 !isRegionReachable(regionB, regionA); 301 } 302 303 // Could not find a common RegionBranchOpInterface among a's and b's 304 // ancestors. 305 return false; 306 } 307 308 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) { 309 Region *region = &getOperation()->getRegion(index); 310 return isRegionReachable(region, region); 311 } 312 313 void RegionBranchOpInterface::getSuccessorRegions( 314 Optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) { 315 unsigned numInputs = 0; 316 if (index) { 317 // If the predecessor is a region, get the number of operands from an 318 // exiting terminator in the region. 319 for (Block &block : getOperation()->getRegion(*index)) { 320 Operation *terminator = block.getTerminator(); 321 if (getRegionBranchSuccessorOperands(terminator, *index)) { 322 numInputs = terminator->getNumOperands(); 323 break; 324 } 325 } 326 } else { 327 // Otherwise, use the number of parent operation operands. 328 numInputs = getOperation()->getNumOperands(); 329 } 330 SmallVector<Attribute, 2> operands(numInputs, nullptr); 331 getSuccessorRegions(index, operands, regions); 332 } 333 334 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { 335 while (Region *region = op->getParentRegion()) { 336 op = region->getParentOp(); 337 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) 338 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) 339 return region; 340 } 341 return nullptr; 342 } 343 344 Region *mlir::getEnclosingRepetitiveRegion(Value value) { 345 Region *region = value.getParentRegion(); 346 while (region) { 347 Operation *op = region->getParentOp(); 348 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) 349 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) 350 return region; 351 region = op->getParentRegion(); 352 } 353 return nullptr; 354 } 355 356 //===----------------------------------------------------------------------===// 357 // RegionBranchTerminatorOpInterface 358 //===----------------------------------------------------------------------===// 359 360 /// Returns true if the given operation is either annotated with the 361 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`. 362 bool mlir::isRegionReturnLike(Operation *operation) { 363 return dyn_cast<RegionBranchTerminatorOpInterface>(operation) || 364 operation->hasTrait<OpTrait::ReturnLike>(); 365 } 366 367 /// Returns the mutable operands that are passed to the region with the given 368 /// `regionIndex`. If the operation does not implement the 369 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the 370 /// result will be `llvm::None`. In all other cases, the resulting 371 /// `OperandRange` represents all operands that are passed to the specified 372 /// successor region. If `regionIndex` is `llvm::None`, all operands that are 373 /// passed to the parent operation will be returned. 374 Optional<MutableOperandRange> 375 mlir::getMutableRegionBranchSuccessorOperands(Operation *operation, 376 Optional<unsigned> regionIndex) { 377 // Try to query a RegionBranchTerminatorOpInterface to determine 378 // all successor operands that will be passed to the successor 379 // input arguments. 380 if (auto regionTerminatorInterface = 381 dyn_cast<RegionBranchTerminatorOpInterface>(operation)) 382 return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex); 383 384 // TODO: The ReturnLike trait should imply a default implementation of the 385 // RegionBranchTerminatorOpInterface. This would make this code significantly 386 // easier. Furthermore, this may even make this function obsolete. 387 if (operation->hasTrait<OpTrait::ReturnLike>()) 388 return MutableOperandRange(operation); 389 return llvm::None; 390 } 391 392 /// Returns the read only operands that are passed to the region with the given 393 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more 394 /// information. 395 Optional<OperandRange> 396 mlir::getRegionBranchSuccessorOperands(Operation *operation, 397 Optional<unsigned> regionIndex) { 398 auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex); 399 return range ? Optional<OperandRange>(*range) : llvm::None; 400 } 401