1 //===- DataFlowAnalysis.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataFlowAnalysis.h" 10 #include "mlir/IR/Operation.h" 11 #include "mlir/Interfaces/CallInterfaces.h" 12 #include "mlir/Interfaces/ControlFlowInterfaces.h" 13 #include "llvm/ADT/ArrayRef.h" 14 #include "llvm/ADT/SmallPtrSet.h" 15 16 #include <queue> 17 18 using namespace mlir; 19 using namespace mlir::detail; 20 21 namespace { 22 /// This class contains various state used when computing the lattice elements 23 /// of a callable operation. 24 class CallableLatticeState { 25 public: 26 /// Build a lattice state with a given callable region, and a specified number 27 /// of results to be initialized to the default lattice element. 28 CallableLatticeState(ForwardDataFlowAnalysisBase &analysis, 29 Region *callableRegion, unsigned numResults) 30 : callableArguments(callableRegion->getArguments()), 31 resultLatticeElements(numResults) { 32 for (AbstractLatticeElement *&it : resultLatticeElements) 33 it = analysis.createLatticeElement(); 34 } 35 36 /// Returns the arguments to the callable region. 37 Block::BlockArgListType getCallableArguments() const { 38 return callableArguments; 39 } 40 41 /// Returns the lattice element for the results of the callable region. 42 auto getResultLatticeElements() { 43 return llvm::make_pointee_range(resultLatticeElements); 44 } 45 46 /// Add a call to this callable. This is only used if the callable defines a 47 /// symbol. 48 void addSymbolCall(Operation *op) { symbolCalls.push_back(op); } 49 50 /// Return the calls that reference this callable. This is only used 51 /// if the callable defines a symbol. 52 ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; } 53 54 private: 55 /// The arguments of the callable region. 56 Block::BlockArgListType callableArguments; 57 58 /// The lattice state for each of the results of this region. The return 59 /// values of the callable aren't SSA values, so we need to track them 60 /// separately. 61 SmallVector<AbstractLatticeElement *, 4> resultLatticeElements; 62 63 /// The calls referencing this callable if this callable defines a symbol. 64 /// This removes the need to recompute symbol references during propagation. 65 /// Value based references are trivial to resolve, so they can be done 66 /// in-place. 67 SmallVector<Operation *, 4> symbolCalls; 68 }; 69 70 /// This class represents the solver for a forward dataflow analysis. This class 71 /// acts as the propagation engine for computing which lattice elements. 72 class ForwardDataFlowSolver { 73 public: 74 /// Initialize the solver with the given top-level operation. 75 ForwardDataFlowSolver(ForwardDataFlowAnalysisBase &analysis, Operation *op); 76 77 /// Run the solver until it converges. 78 void solve(); 79 80 private: 81 /// Initialize the set of symbol defining callables that can have their 82 /// arguments and results tracked. 'op' is the top-level operation that the 83 /// solver is operating on. 84 void initializeSymbolCallables(Operation *op); 85 86 /// Visit the users of the given IR that reside within executable blocks. 87 template <typename T> 88 void visitUsers(T &value) { 89 for (Operation *user : value.getUsers()) 90 if (isBlockExecutable(user->getBlock())) 91 visitOperation(user); 92 } 93 94 /// Visit the given operation and compute any necessary lattice state. 95 void visitOperation(Operation *op); 96 97 /// Visit the given call operation and compute any necessary lattice state. 98 void visitCallOperation(CallOpInterface op); 99 100 /// Visit the given callable operation and compute any necessary lattice 101 /// state. 102 void visitCallableOperation(Operation *op); 103 104 /// Visit the given region branch operation, which defines regions, and 105 /// compute any necessary lattice state. This also resolves the lattice state 106 /// of both the operation results and any nested regions. 107 void visitRegionBranchOperation( 108 RegionBranchOpInterface branch, 109 ArrayRef<AbstractLatticeElement *> operandLattices); 110 111 /// Visit the given set of region successors, computing any necessary lattice 112 /// state. The provided function returns the input operands to the region at 113 /// the given index. If the index is 'None', the input operands correspond to 114 /// the parent operation results. 115 void visitRegionSuccessors( 116 Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors, 117 ArrayRef<AbstractLatticeElement *> operandLattices, 118 function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion); 119 120 /// Visit the given terminator operation and compute any necessary lattice 121 /// state. 122 void 123 visitTerminatorOperation(Operation *op, 124 ArrayRef<AbstractLatticeElement *> operandLattices); 125 126 /// Visit the given terminator operation that exits a callable region. These 127 /// are terminators with no CFG successors. 128 void visitCallableTerminatorOperation( 129 Operation *callable, Operation *terminator, 130 ArrayRef<AbstractLatticeElement *> operandLattices); 131 132 /// Visit the given block and compute any necessary lattice state. 133 void visitBlock(Block *block); 134 135 /// Visit argument #'i' of the given block and compute any necessary lattice 136 /// state. 137 void visitBlockArgument(Block *block, int i); 138 139 /// Mark the entry block of the given region as executable. Returns NoChange 140 /// if the block was already marked executable. If `markPessimisticFixpoint` 141 /// is true, the arguments of the entry block are also marked as having 142 /// reached the pessimistic fixpoint. 143 ChangeResult markEntryBlockExecutable(Region *region, 144 bool markPessimisticFixpoint); 145 146 /// Mark the given block as executable. Returns NoChange if the block was 147 /// already marked executable. 148 ChangeResult markBlockExecutable(Block *block); 149 150 /// Returns true if the given block is executable. 151 bool isBlockExecutable(Block *block) const; 152 153 /// Mark the edge between 'from' and 'to' as executable. 154 void markEdgeExecutable(Block *from, Block *to); 155 156 /// Return true if the edge between 'from' and 'to' is executable. 157 bool isEdgeExecutable(Block *from, Block *to) const; 158 159 /// Mark the given value as having reached the pessimistic fixpoint. This 160 /// means that we cannot further refine the state of this value. 161 void markPessimisticFixpoint(Value value); 162 163 /// Mark all of the given values as having reaching the pessimistic fixpoint. 164 template <typename ValuesT> 165 void markAllPessimisticFixpoint(ValuesT values) { 166 for (auto value : values) 167 markPessimisticFixpoint(value); 168 } 169 template <typename ValuesT> 170 void markAllPessimisticFixpoint(Operation *op, ValuesT values) { 171 markAllPessimisticFixpoint(values); 172 opWorklist.push(op); 173 } 174 template <typename ValuesT> 175 void markAllPessimisticFixpointAndVisitUsers(ValuesT values) { 176 for (auto value : values) { 177 AbstractLatticeElement &lattice = analysis.getLatticeElement(value); 178 if (lattice.markPessimisticFixpoint() == ChangeResult::Change) 179 visitUsers(value); 180 } 181 } 182 183 /// Returns true if the given value was marked as having reached the 184 /// pessimistic fixpoint. 185 bool isAtFixpoint(Value value) const; 186 187 /// Merge in the given lattice 'from' into the lattice 'to'. 'owner' 188 /// corresponds to the parent operation of the lattice for 'to'. 189 void join(Operation *owner, AbstractLatticeElement &to, 190 const AbstractLatticeElement &from); 191 192 /// A reference to the dataflow analysis being computed. 193 ForwardDataFlowAnalysisBase &analysis; 194 195 /// The set of blocks that are known to execute, or are intrinsically live. 196 SmallPtrSet<Block *, 16> executableBlocks; 197 198 /// The set of control flow edges that are known to execute. 199 DenseSet<std::pair<Block *, Block *>> executableEdges; 200 201 /// A worklist containing blocks that need to be processed. 202 std::queue<Block *> blockWorklist; 203 204 /// A worklist of operations that need to be processed. 205 std::queue<Operation *> opWorklist; 206 207 /// The callable operations that have their argument/result state tracked. 208 DenseMap<Operation *, CallableLatticeState> callableLatticeState; 209 210 /// A map between a call operation and the resolved symbol callable. This 211 /// avoids re-resolving symbol references during propagation. Value based 212 /// callables are trivial to resolve, so they can be done in-place. 213 DenseMap<Operation *, Operation *> callToSymbolCallable; 214 215 /// A symbol table used for O(1) symbol lookups during simplification. 216 SymbolTableCollection symbolTable; 217 }; 218 } // namespace 219 220 ForwardDataFlowSolver::ForwardDataFlowSolver( 221 ForwardDataFlowAnalysisBase &analysis, Operation *op) 222 : analysis(analysis) { 223 /// Initialize the solver with the regions within this operation. 224 for (Region ®ion : op->getRegions()) { 225 // Mark the entry block as executable. The values passed to these regions 226 // are also invisible, so mark any arguments as reaching the pessimistic 227 // fixpoint. 228 markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); 229 } 230 initializeSymbolCallables(op); 231 } 232 233 void ForwardDataFlowSolver::solve() { 234 while (!blockWorklist.empty() || !opWorklist.empty()) { 235 // Process any operations in the op worklist. 236 while (!opWorklist.empty()) { 237 Operation *nextOp = opWorklist.front(); 238 opWorklist.pop(); 239 visitUsers(*nextOp); 240 } 241 242 // Process any blocks in the block worklist. 243 while (!blockWorklist.empty()) { 244 Block *nextBlock = blockWorklist.front(); 245 blockWorklist.pop(); 246 visitBlock(nextBlock); 247 } 248 } 249 } 250 251 void ForwardDataFlowSolver::initializeSymbolCallables(Operation *op) { 252 // Initialize the set of symbol callables that can have their state tracked. 253 // This tracks which symbol callable operations we can propagate within and 254 // out of. 255 auto walkFn = [&](Operation *symTable, bool allUsesVisible) { 256 Region &symbolTableRegion = symTable->getRegion(0); 257 Block *symbolTableBlock = &symbolTableRegion.front(); 258 for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) { 259 // We won't be able to track external callables. 260 Region *callableRegion = callable.getCallableRegion(); 261 if (!callableRegion) 262 continue; 263 // We only care about symbol defining callables here. 264 auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation()); 265 if (!symbol) 266 continue; 267 callableLatticeState.try_emplace(callable, analysis, callableRegion, 268 callable.getCallableResults().size()); 269 270 // If not all of the uses of this symbol are visible, we can't track the 271 // state of the arguments. 272 if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) { 273 for (Region ®ion : callable->getRegions()) 274 markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); 275 } 276 } 277 if (callableLatticeState.empty()) 278 return; 279 280 // After computing the valid callables, walk any symbol uses to check 281 // for non-call references. We won't be able to track the lattice state 282 // for arguments to these callables, as we can't guarantee that we can see 283 // all of its calls. 284 Optional<SymbolTable::UseRange> uses = 285 SymbolTable::getSymbolUses(&symbolTableRegion); 286 if (!uses) { 287 // If we couldn't gather the symbol uses, conservatively assume that 288 // we can't track information for any nested symbols. 289 op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); }); 290 return; 291 } 292 293 for (const SymbolTable::SymbolUse &use : *uses) { 294 // If the use is a call, track it to avoid the need to recompute the 295 // reference later. 296 if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) { 297 Operation *symCallable = callOp.resolveCallable(&symbolTable); 298 auto callableLatticeIt = callableLatticeState.find(symCallable); 299 if (callableLatticeIt != callableLatticeState.end()) { 300 callToSymbolCallable.try_emplace(callOp, symCallable); 301 302 // We only need to record the call in the lattice if it produces any 303 // values. 304 if (callOp->getNumResults()) 305 callableLatticeIt->second.addSymbolCall(callOp); 306 } 307 continue; 308 } 309 // This use isn't a call, so don't we know all of the callers. 310 auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef()); 311 auto it = callableLatticeState.find(symbol); 312 if (it != callableLatticeState.end()) { 313 for (Region ®ion : it->first->getRegions()) 314 markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); 315 } 316 } 317 }; 318 SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), 319 walkFn); 320 } 321 322 void ForwardDataFlowSolver::visitOperation(Operation *op) { 323 // Collect all of the lattice elements feeding into this operation. If any are 324 // not yet resolved, bail out and wait for them to resolve. 325 SmallVector<AbstractLatticeElement *, 8> operandLattices; 326 operandLattices.reserve(op->getNumOperands()); 327 for (Value operand : op->getOperands()) { 328 AbstractLatticeElement *operandLattice = 329 analysis.lookupLatticeElement(operand); 330 if (!operandLattice || operandLattice->isUninitialized()) 331 return; 332 operandLattices.push_back(operandLattice); 333 } 334 335 // If this is a terminator operation, process any control flow lattice state. 336 if (op->hasTrait<OpTrait::IsTerminator>()) 337 visitTerminatorOperation(op, operandLattices); 338 339 // Process call operations. The call visitor processes result values, so we 340 // can exit afterwards. 341 if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) 342 return visitCallOperation(call); 343 344 // Process callable operations. These are specially handled region operations 345 // that track dataflow via calls. 346 if (isa<CallableOpInterface>(op)) { 347 // If this callable has a tracked lattice state, it will be visited by calls 348 // that reference it instead. This way, we don't assume that it is 349 // executable unless there is a proper reference to it. 350 if (callableLatticeState.count(op)) 351 return; 352 return visitCallableOperation(op); 353 } 354 355 // Process region holding operations. 356 if (op->getNumRegions()) { 357 // Check to see if we can reason about the internal control flow of this 358 // region operation. 359 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) 360 return visitRegionBranchOperation(branch, operandLattices); 361 362 // If we can't, conservatively mark all regions as executable. 363 // TODO: Let the `visitOperation` method decide how to propagate 364 // information to the block arguments. 365 for (Region ®ion : op->getRegions()) 366 markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); 367 } 368 369 // If this op produces no results, it can't produce any constants. 370 if (op->getNumResults() == 0) 371 return; 372 373 // If all of the results of this operation are already resolved, bail out 374 // early. 375 auto isAtFixpointFn = [&](Value value) { return isAtFixpoint(value); }; 376 if (llvm::all_of(op->getResults(), isAtFixpointFn)) 377 return; 378 379 // Visit the current operation. 380 if (analysis.visitOperation(op, operandLattices) == ChangeResult::Change) 381 opWorklist.push(op); 382 383 // `visitOperation` is required to define all of the result lattices. 384 assert(llvm::none_of( 385 op->getResults(), 386 [&](Value value) { 387 return analysis.getLatticeElement(value).isUninitialized(); 388 }) && 389 "expected `visitOperation` to define all result lattices"); 390 } 391 392 void ForwardDataFlowSolver::visitCallableOperation(Operation *op) { 393 // Mark the regions as executable. If we aren't tracking lattice state for 394 // this callable, mark all of the region arguments as having reached a 395 // fixpoint. 396 bool isTrackingLatticeState = callableLatticeState.count(op); 397 for (Region ®ion : op->getRegions()) 398 markEntryBlockExecutable(®ion, !isTrackingLatticeState); 399 400 // TODO: Add support for non-symbol callables when necessary. If the callable 401 // has non-call uses we would mark as having reached pessimistic fixpoint, 402 // otherwise allow for propagating the return values out. 403 markAllPessimisticFixpoint(op, op->getResults()); 404 } 405 406 void ForwardDataFlowSolver::visitCallOperation(CallOpInterface op) { 407 ResultRange callResults = op->getResults(); 408 409 // Resolve the callable operation for this call. 410 Operation *callableOp = nullptr; 411 if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>()) 412 callableOp = callableValue.getDefiningOp(); 413 else 414 callableOp = callToSymbolCallable.lookup(op); 415 416 // The callable of this call can't be resolved, mark any results overdefined. 417 if (!callableOp) 418 return markAllPessimisticFixpoint(op, callResults); 419 420 // If this callable is tracking state, merge the argument operands with the 421 // arguments of the callable. 422 auto callableLatticeIt = callableLatticeState.find(callableOp); 423 if (callableLatticeIt == callableLatticeState.end()) 424 return markAllPessimisticFixpoint(op, callResults); 425 426 OperandRange callOperands = op.getArgOperands(); 427 auto callableArgs = callableLatticeIt->second.getCallableArguments(); 428 for (auto it : llvm::zip(callOperands, callableArgs)) { 429 BlockArgument callableArg = std::get<1>(it); 430 AbstractLatticeElement &argValue = analysis.getLatticeElement(callableArg); 431 AbstractLatticeElement &operandValue = 432 analysis.getLatticeElement(std::get<0>(it)); 433 if (argValue.join(operandValue) == ChangeResult::Change) 434 visitUsers(callableArg); 435 } 436 437 // Visit the callable. 438 visitCallableOperation(callableOp); 439 440 // Merge in the lattice state for the callable results as well. 441 auto callableResults = callableLatticeIt->second.getResultLatticeElements(); 442 for (auto it : llvm::zip(callResults, callableResults)) 443 join(/*owner=*/op, 444 /*to=*/analysis.getLatticeElement(std::get<0>(it)), 445 /*from=*/std::get<1>(it)); 446 } 447 448 void ForwardDataFlowSolver::visitRegionBranchOperation( 449 RegionBranchOpInterface branch, 450 ArrayRef<AbstractLatticeElement *> operandLattices) { 451 // Check to see which regions are executable. 452 SmallVector<RegionSuccessor, 1> successors; 453 analysis.getSuccessorsForOperands(branch, /*sourceIndex=*/llvm::None, 454 operandLattices, successors); 455 456 // If the interface identified that no region will be executed. Mark 457 // any results of this operation as overdefined, as we can't reason about 458 // them. 459 // TODO: If we had an interface to detect pass through operands, we could 460 // resolve some results based on the lattice state of the operands. We could 461 // also allow for the parent operation to have itself as a region successor. 462 if (successors.empty()) 463 return markAllPessimisticFixpoint(branch, branch->getResults()); 464 return visitRegionSuccessors( 465 branch, successors, operandLattices, [&](Optional<unsigned> index) { 466 assert(index && "expected valid region index"); 467 return branch.getSuccessorEntryOperands(*index); 468 }); 469 } 470 471 void ForwardDataFlowSolver::visitRegionSuccessors( 472 Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors, 473 ArrayRef<AbstractLatticeElement *> operandLattices, 474 function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) { 475 for (const RegionSuccessor &it : regionSuccessors) { 476 Region *region = it.getSuccessor(); 477 ValueRange succArgs = it.getSuccessorInputs(); 478 479 // Check to see if this is the parent operation. 480 if (!region) { 481 ResultRange results = parentOp->getResults(); 482 if (llvm::all_of(results, [&](Value res) { return isAtFixpoint(res); })) 483 continue; 484 485 // Mark the results outside of the input range as having reached the 486 // pessimistic fixpoint. 487 // TODO: This isn't exactly ideal. There may be situations in which a 488 // region operation can provide information for certain results that 489 // aren't part of the control flow. 490 if (succArgs.size() != results.size()) { 491 opWorklist.push(parentOp); 492 if (succArgs.empty()) { 493 markAllPessimisticFixpoint(results); 494 continue; 495 } 496 497 unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber(); 498 markAllPessimisticFixpoint(results.take_front(firstResIdx)); 499 markAllPessimisticFixpoint( 500 results.drop_front(firstResIdx + succArgs.size())); 501 } 502 503 // Update the lattice for any operation results. 504 OperandRange operands = getInputsForRegion(/*index=*/llvm::None); 505 for (auto it : llvm::zip(succArgs, operands)) 506 join(parentOp, analysis.getLatticeElement(std::get<0>(it)), 507 analysis.getLatticeElement(std::get<1>(it))); 508 continue; 509 } 510 assert(!region->empty() && "expected region to be non-empty"); 511 Block *entryBlock = ®ion->front(); 512 markBlockExecutable(entryBlock); 513 514 // If all of the arguments have already reached a fixpoint, the arguments 515 // have already been fully resolved. 516 Block::BlockArgListType arguments = entryBlock->getArguments(); 517 if (llvm::all_of(arguments, [&](Value arg) { return isAtFixpoint(arg); })) 518 continue; 519 520 if (succArgs.size() != arguments.size()) { 521 if (analysis.visitNonControlFlowArguments( 522 parentOp, it, operandLattices) == ChangeResult::Change) { 523 unsigned firstArgIdx = 524 succArgs.empty() ? 0 525 : succArgs[0].cast<BlockArgument>().getArgNumber(); 526 for (Value v : arguments.take_front(firstArgIdx)) { 527 assert(!analysis.getLatticeElement(v).isUninitialized() && 528 "Non-control flow block arg has no lattice value after " 529 "analysis callback"); 530 visitUsers(v); 531 } 532 for (Value v : arguments.drop_front(firstArgIdx + succArgs.size())) { 533 assert(!analysis.getLatticeElement(v).isUninitialized() && 534 "Non-control flow block arg has no lattice value after " 535 "analysis callback"); 536 visitUsers(v); 537 } 538 } 539 } 540 541 // Update the lattice of arguments that have inputs from the predecessor. 542 OperandRange succOperands = getInputsForRegion(region->getRegionNumber()); 543 for (auto it : llvm::zip(succArgs, succOperands)) { 544 AbstractLatticeElement &argValue = 545 analysis.getLatticeElement(std::get<0>(it)); 546 AbstractLatticeElement &operandValue = 547 analysis.getLatticeElement(std::get<1>(it)); 548 if (argValue.join(operandValue) == ChangeResult::Change) 549 visitUsers(std::get<0>(it)); 550 } 551 } 552 } 553 554 void ForwardDataFlowSolver::visitTerminatorOperation( 555 Operation *op, ArrayRef<AbstractLatticeElement *> operandLattices) { 556 // If this operation has no successors, we treat it as an exiting terminator. 557 if (op->getNumSuccessors() == 0) { 558 Region *parentRegion = op->getParentRegion(); 559 Operation *parentOp = parentRegion->getParentOp(); 560 561 // Check to see if this is a terminator for a callable region. 562 if (isa<CallableOpInterface>(parentOp)) 563 return visitCallableTerminatorOperation(parentOp, op, operandLattices); 564 565 // Otherwise, check to see if the parent tracks region control flow. 566 auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp); 567 if (!regionInterface || !isBlockExecutable(parentOp->getBlock())) 568 return; 569 570 // Query the set of successors of the current region using the current 571 // optimistic lattice state. 572 SmallVector<RegionSuccessor, 1> regionSuccessors; 573 analysis.getSuccessorsForOperands(regionInterface, 574 parentRegion->getRegionNumber(), 575 operandLattices, regionSuccessors); 576 if (regionSuccessors.empty()) 577 return; 578 579 // Try to get "region-like" successor operands if possible in order to 580 // propagate the operand states to the successors. 581 if (isRegionReturnLike(op)) { 582 auto getOperands = [&](Optional<unsigned> regionIndex) { 583 // Determine the individual region successor operands for the given 584 // region index (if any). 585 return *getRegionBranchSuccessorOperands(op, regionIndex); 586 }; 587 return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices, 588 getOperands); 589 } 590 591 // If this terminator is not "region-like", conservatively mark all of the 592 // successor values as having reached the pessimistic fixpoint. 593 for (auto &it : regionSuccessors) { 594 // If the successor is a region, mark the entry block as executable so 595 // that we visit operations defined within. If the successor is the 596 // parent operation, we simply mark the control flow results as having 597 // reached the pessimistic state. 598 if (Region *region = it.getSuccessor()) 599 markEntryBlockExecutable(region, /*markPessimisticFixpoint=*/true); 600 else 601 markAllPessimisticFixpointAndVisitUsers(it.getSuccessorInputs()); 602 } 603 } 604 605 // Try to resolve to a specific set of successors with the current optimistic 606 // lattice state. 607 Block *block = op->getBlock(); 608 if (auto branch = dyn_cast<BranchOpInterface>(op)) { 609 SmallVector<Block *> successors; 610 if (succeeded(analysis.getSuccessorsForOperands(branch, operandLattices, 611 successors))) { 612 for (Block *succ : successors) 613 markEdgeExecutable(block, succ); 614 return; 615 } 616 } 617 618 // Otherwise, conservatively treat all edges as executable. 619 for (Block *succ : op->getSuccessors()) 620 markEdgeExecutable(block, succ); 621 } 622 623 void ForwardDataFlowSolver::visitCallableTerminatorOperation( 624 Operation *callable, Operation *terminator, 625 ArrayRef<AbstractLatticeElement *> operandLattices) { 626 // If there are no exiting values, we have nothing to track. 627 if (terminator->getNumOperands() == 0) 628 return; 629 630 // If this callable isn't tracking any lattice state there is nothing to do. 631 auto latticeIt = callableLatticeState.find(callable); 632 if (latticeIt == callableLatticeState.end()) 633 return; 634 assert(callable->getNumResults() == 0 && "expected symbol callable"); 635 636 // If this terminator is not "return-like", conservatively mark all of the 637 // call-site results as having reached the pessimistic fixpoint. 638 auto callableResultLattices = latticeIt->second.getResultLatticeElements(); 639 if (!terminator->hasTrait<OpTrait::ReturnLike>()) { 640 for (auto &it : callableResultLattices) 641 it.markPessimisticFixpoint(); 642 for (Operation *call : latticeIt->second.getSymbolCalls()) 643 markAllPessimisticFixpoint(call, call->getResults()); 644 return; 645 } 646 647 // Merge the lattice state for terminator operands into the results. 648 ChangeResult result = ChangeResult::NoChange; 649 for (auto it : llvm::zip(operandLattices, callableResultLattices)) 650 result |= std::get<1>(it).join(*std::get<0>(it)); 651 if (result == ChangeResult::NoChange) 652 return; 653 654 // If any of the result lattices changed, update the callers. 655 for (Operation *call : latticeIt->second.getSymbolCalls()) 656 for (auto it : llvm::zip(call->getResults(), callableResultLattices)) 657 join(call, analysis.getLatticeElement(std::get<0>(it)), std::get<1>(it)); 658 } 659 660 void ForwardDataFlowSolver::visitBlock(Block *block) { 661 // If the block is not the entry block we need to compute the lattice state 662 // for the block arguments. Entry block argument lattices are computed 663 // elsewhere, such as when visiting the parent operation. 664 if (!block->isEntryBlock()) { 665 for (int i : llvm::seq<int>(0, block->getNumArguments())) 666 visitBlockArgument(block, i); 667 } 668 669 // Visit all of the operations within the block. 670 for (Operation &op : *block) 671 visitOperation(&op); 672 } 673 674 void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) { 675 BlockArgument arg = block->getArgument(i); 676 AbstractLatticeElement &argLattice = analysis.getLatticeElement(arg); 677 if (argLattice.isAtFixpoint()) 678 return; 679 680 ChangeResult updatedLattice = ChangeResult::NoChange; 681 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { 682 Block *pred = *it; 683 684 // We only care about this predecessor if it is going to execute. 685 if (!isEdgeExecutable(pred, block)) 686 continue; 687 688 // Try to get the operand forwarded by the predecessor. If we can't reason 689 // about the terminator of the predecessor, mark as having reached a 690 // fixpoint. 691 auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator()); 692 if (!branch) { 693 updatedLattice |= argLattice.markPessimisticFixpoint(); 694 break; 695 } 696 Value operand = branch.getSuccessorOperands(it.getSuccessorIndex())[i]; 697 if (!operand) { 698 updatedLattice |= argLattice.markPessimisticFixpoint(); 699 break; 700 } 701 702 // If the operand hasn't been resolved, it is uninitialized and can merge 703 // with anything. 704 AbstractLatticeElement *operandLattice = 705 analysis.lookupLatticeElement(operand); 706 if (!operandLattice) 707 continue; 708 709 // Otherwise, join the operand lattice into the argument lattice. 710 updatedLattice |= argLattice.join(*operandLattice); 711 if (argLattice.isAtFixpoint()) 712 break; 713 } 714 715 // If the lattice changed, visit users of the argument. 716 if (updatedLattice == ChangeResult::Change) 717 visitUsers(arg); 718 } 719 720 ChangeResult 721 ForwardDataFlowSolver::markEntryBlockExecutable(Region *region, 722 bool markPessimisticFixpoint) { 723 if (!region->empty()) { 724 if (markPessimisticFixpoint) 725 markAllPessimisticFixpoint(region->front().getArguments()); 726 return markBlockExecutable(®ion->front()); 727 } 728 return ChangeResult::NoChange; 729 } 730 731 ChangeResult ForwardDataFlowSolver::markBlockExecutable(Block *block) { 732 bool marked = executableBlocks.insert(block).second; 733 if (marked) 734 blockWorklist.push(block); 735 return marked ? ChangeResult::Change : ChangeResult::NoChange; 736 } 737 738 bool ForwardDataFlowSolver::isBlockExecutable(Block *block) const { 739 return executableBlocks.count(block); 740 } 741 742 void ForwardDataFlowSolver::markEdgeExecutable(Block *from, Block *to) { 743 executableEdges.insert(std::make_pair(from, to)); 744 745 // Mark the destination as executable, and reprocess its arguments if it was 746 // already executable. 747 if (markBlockExecutable(to) == ChangeResult::NoChange) { 748 for (int i : llvm::seq<int>(0, to->getNumArguments())) 749 visitBlockArgument(to, i); 750 } 751 } 752 753 bool ForwardDataFlowSolver::isEdgeExecutable(Block *from, Block *to) const { 754 return executableEdges.count(std::make_pair(from, to)); 755 } 756 757 void ForwardDataFlowSolver::markPessimisticFixpoint(Value value) { 758 analysis.getLatticeElement(value).markPessimisticFixpoint(); 759 } 760 761 bool ForwardDataFlowSolver::isAtFixpoint(Value value) const { 762 if (auto *lattice = analysis.lookupLatticeElement(value)) 763 return lattice->isAtFixpoint(); 764 return false; 765 } 766 767 void ForwardDataFlowSolver::join(Operation *owner, AbstractLatticeElement &to, 768 const AbstractLatticeElement &from) { 769 if (to.join(from) == ChangeResult::Change) 770 opWorklist.push(owner); 771 } 772 773 //===----------------------------------------------------------------------===// 774 // AbstractLatticeElement 775 //===----------------------------------------------------------------------===// 776 777 AbstractLatticeElement::~AbstractLatticeElement() = default; 778 779 //===----------------------------------------------------------------------===// 780 // ForwardDataFlowAnalysisBase 781 //===----------------------------------------------------------------------===// 782 783 ForwardDataFlowAnalysisBase::~ForwardDataFlowAnalysisBase() = default; 784 785 AbstractLatticeElement & 786 ForwardDataFlowAnalysisBase::getLatticeElement(Value value) { 787 AbstractLatticeElement *&latticeValue = latticeValues[value]; 788 if (!latticeValue) 789 latticeValue = createLatticeElement(value); 790 return *latticeValue; 791 } 792 793 AbstractLatticeElement * 794 ForwardDataFlowAnalysisBase::lookupLatticeElement(Value value) { 795 return latticeValues.lookup(value); 796 } 797 798 void ForwardDataFlowAnalysisBase::run(Operation *topLevelOp) { 799 // Run the main dataflow solver. 800 ForwardDataFlowSolver solver(*this, topLevelOp); 801 solver.solve(); 802 803 // Any values that are still uninitialized now go to a pessimistic fixpoint, 804 // otherwise we assume an optimistic fixpoint has been reached. 805 for (auto &it : latticeValues) 806 if (it.second->isUninitialized()) 807 it.second->markPessimisticFixpoint(); 808 else 809 it.second->markOptimisticFixpoint(); 810 } 811