1 //===- DataFlowAnalysis.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataFlowAnalysis.h" 10 #include "mlir/IR/Operation.h" 11 #include "mlir/Interfaces/CallInterfaces.h" 12 #include "mlir/Interfaces/ControlFlowInterfaces.h" 13 #include "llvm/ADT/SmallPtrSet.h" 14 15 #include <queue> 16 17 using namespace mlir; 18 using namespace mlir::detail; 19 20 namespace { 21 /// This class contains various state used when computing the lattice elements 22 /// of a callable operation. 23 class CallableLatticeState { 24 public: 25 /// Build a lattice state with a given callable region, and a specified number 26 /// of results to be initialized to the default lattice element. 27 CallableLatticeState(ForwardDataFlowAnalysisBase &analysis, 28 Region *callableRegion, unsigned numResults) 29 : callableArguments(callableRegion->getArguments()), 30 resultLatticeElements(numResults) { 31 for (AbstractLatticeElement *&it : resultLatticeElements) 32 it = analysis.createLatticeElement(); 33 } 34 35 /// Returns the arguments to the callable region. 36 Block::BlockArgListType getCallableArguments() const { 37 return callableArguments; 38 } 39 40 /// Returns the lattice element for the results of the callable region. 41 auto getResultLatticeElements() { 42 return llvm::make_pointee_range(resultLatticeElements); 43 } 44 45 /// Add a call to this callable. This is only used if the callable defines a 46 /// symbol. 47 void addSymbolCall(Operation *op) { symbolCalls.push_back(op); } 48 49 /// Return the calls that reference this callable. This is only used 50 /// if the callable defines a symbol. 51 ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; } 52 53 private: 54 /// The arguments of the callable region. 55 Block::BlockArgListType callableArguments; 56 57 /// The lattice state for each of the results of this region. The return 58 /// values of the callable aren't SSA values, so we need to track them 59 /// separately. 60 SmallVector<AbstractLatticeElement *, 4> resultLatticeElements; 61 62 /// The calls referencing this callable if this callable defines a symbol. 63 /// This removes the need to recompute symbol references during propagation. 64 /// Value based references are trivial to resolve, so they can be done 65 /// in-place. 66 SmallVector<Operation *, 4> symbolCalls; 67 }; 68 69 /// This class represents the solver for a forward dataflow analysis. This class 70 /// acts as the propagation engine for computing which lattice elements. 71 class ForwardDataFlowSolver { 72 public: 73 /// Initialize the solver with the given top-level operation. 74 ForwardDataFlowSolver(ForwardDataFlowAnalysisBase &analysis, Operation *op); 75 76 /// Run the solver until it converges. 77 void solve(); 78 79 private: 80 /// Initialize the set of symbol defining callables that can have their 81 /// arguments and results tracked. 'op' is the top-level operation that the 82 /// solver is operating on. 83 void initializeSymbolCallables(Operation *op); 84 85 /// Visit the users of the given IR that reside within executable blocks. 86 template <typename T> 87 void visitUsers(T &value) { 88 for (Operation *user : value.getUsers()) 89 if (isBlockExecutable(user->getBlock())) 90 visitOperation(user); 91 } 92 93 /// Visit the given operation and compute any necessary lattice state. 94 void visitOperation(Operation *op); 95 96 /// Visit the given call operation and compute any necessary lattice state. 97 void visitCallOperation(CallOpInterface op); 98 99 /// Visit the given callable operation and compute any necessary lattice 100 /// state. 101 void visitCallableOperation(Operation *op); 102 103 /// Visit the given region branch operation, which defines regions, and 104 /// compute any necessary lattice state. This also resolves the lattice state 105 /// of both the operation results and any nested regions. 106 void visitRegionBranchOperation( 107 RegionBranchOpInterface branch, 108 ArrayRef<AbstractLatticeElement *> operandLattices); 109 110 /// Visit the given set of region successors, computing any necessary lattice 111 /// state. The provided function returns the input operands to the region at 112 /// the given index. If the index is 'None', the input operands correspond to 113 /// the parent operation results. 114 void visitRegionSuccessors( 115 Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors, 116 function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion); 117 118 /// Visit the given terminator operation and compute any necessary lattice 119 /// state. 120 void 121 visitTerminatorOperation(Operation *op, 122 ArrayRef<AbstractLatticeElement *> operandLattices); 123 124 /// Visit the given terminator operation that exits a callable region. These 125 /// are terminators with no CFG successors. 126 void visitCallableTerminatorOperation( 127 Operation *callable, Operation *terminator, 128 ArrayRef<AbstractLatticeElement *> operandLattices); 129 130 /// Visit the given block and compute any necessary lattice state. 131 void visitBlock(Block *block); 132 133 /// Visit argument #'i' of the given block and compute any necessary lattice 134 /// state. 135 void visitBlockArgument(Block *block, int i); 136 137 /// Mark the entry block of the given region as executable. Returns NoChange 138 /// if the block was already marked executable. If `markPessimisticFixpoint` 139 /// is true, the arguments of the entry block are also marked as having 140 /// reached the pessimistic fixpoint. 141 ChangeResult markEntryBlockExecutable(Region *region, 142 bool markPessimisticFixpoint); 143 144 /// Mark the given block as executable. Returns NoChange if the block was 145 /// already marked executable. 146 ChangeResult markBlockExecutable(Block *block); 147 148 /// Returns true if the given block is executable. 149 bool isBlockExecutable(Block *block) const; 150 151 /// Mark the edge between 'from' and 'to' as executable. 152 void markEdgeExecutable(Block *from, Block *to); 153 154 /// Return true if the edge between 'from' and 'to' is executable. 155 bool isEdgeExecutable(Block *from, Block *to) const; 156 157 /// Mark the given value as having reached the pessimistic fixpoint. This 158 /// means that we cannot further refine the state of this value. 159 void markPessimisticFixpoint(Value value); 160 161 /// Mark all of the given values as having reaching the pessimistic fixpoint. 162 template <typename ValuesT> 163 void markAllPessimisticFixpoint(ValuesT values) { 164 for (auto value : values) 165 markPessimisticFixpoint(value); 166 } 167 template <typename ValuesT> 168 void markAllPessimisticFixpoint(Operation *op, ValuesT values) { 169 markAllPessimisticFixpoint(values); 170 opWorklist.push(op); 171 } 172 template <typename ValuesT> 173 void markAllPessimisticFixpointAndVisitUsers(ValuesT values) { 174 for (auto value : values) { 175 AbstractLatticeElement &lattice = analysis.getLatticeElement(value); 176 if (lattice.markPessimisticFixpoint() == ChangeResult::Change) 177 visitUsers(value); 178 } 179 } 180 181 /// Returns true if the given value was marked as having reached the 182 /// pessimistic fixpoint. 183 bool isAtFixpoint(Value value) const; 184 185 /// Merge in the given lattice 'from' into the lattice 'to'. 'owner' 186 /// corresponds to the parent operation of the lattice for 'to'. 187 void join(Operation *owner, AbstractLatticeElement &to, 188 const AbstractLatticeElement &from); 189 190 /// A reference to the dataflow analysis being computed. 191 ForwardDataFlowAnalysisBase &analysis; 192 193 /// The set of blocks that are known to execute, or are intrinsically live. 194 SmallPtrSet<Block *, 16> executableBlocks; 195 196 /// The set of control flow edges that are known to execute. 197 DenseSet<std::pair<Block *, Block *>> executableEdges; 198 199 /// A worklist containing blocks that need to be processed. 200 std::queue<Block *> blockWorklist; 201 202 /// A worklist of operations that need to be processed. 203 std::queue<Operation *> opWorklist; 204 205 /// The callable operations that have their argument/result state tracked. 206 DenseMap<Operation *, CallableLatticeState> callableLatticeState; 207 208 /// A map between a call operation and the resolved symbol callable. This 209 /// avoids re-resolving symbol references during propagation. Value based 210 /// callables are trivial to resolve, so they can be done in-place. 211 DenseMap<Operation *, Operation *> callToSymbolCallable; 212 213 /// A symbol table used for O(1) symbol lookups during simplification. 214 SymbolTableCollection symbolTable; 215 }; 216 } // namespace 217 218 ForwardDataFlowSolver::ForwardDataFlowSolver( 219 ForwardDataFlowAnalysisBase &analysis, Operation *op) 220 : analysis(analysis) { 221 /// Initialize the solver with the regions within this operation. 222 for (Region ®ion : op->getRegions()) { 223 // Mark the entry block as executable. The values passed to these regions 224 // are also invisible, so mark any arguments as reaching the pessimistic 225 // fixpoint. 226 markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); 227 } 228 initializeSymbolCallables(op); 229 } 230 231 void ForwardDataFlowSolver::solve() { 232 while (!blockWorklist.empty() || !opWorklist.empty()) { 233 // Process any operations in the op worklist. 234 while (!opWorklist.empty()) { 235 Operation *nextOp = opWorklist.front(); 236 opWorklist.pop(); 237 visitUsers(*nextOp); 238 } 239 240 // Process any blocks in the block worklist. 241 while (!blockWorklist.empty()) { 242 Block *nextBlock = blockWorklist.front(); 243 blockWorklist.pop(); 244 visitBlock(nextBlock); 245 } 246 } 247 } 248 249 void ForwardDataFlowSolver::initializeSymbolCallables(Operation *op) { 250 // Initialize the set of symbol callables that can have their state tracked. 251 // This tracks which symbol callable operations we can propagate within and 252 // out of. 253 auto walkFn = [&](Operation *symTable, bool allUsesVisible) { 254 Region &symbolTableRegion = symTable->getRegion(0); 255 Block *symbolTableBlock = &symbolTableRegion.front(); 256 for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) { 257 // We won't be able to track external callables. 258 Region *callableRegion = callable.getCallableRegion(); 259 if (!callableRegion) 260 continue; 261 // We only care about symbol defining callables here. 262 auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation()); 263 if (!symbol) 264 continue; 265 callableLatticeState.try_emplace(callable, analysis, callableRegion, 266 callable.getCallableResults().size()); 267 268 // If not all of the uses of this symbol are visible, we can't track the 269 // state of the arguments. 270 if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) { 271 for (Region ®ion : callable->getRegions()) 272 markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); 273 } 274 } 275 if (callableLatticeState.empty()) 276 return; 277 278 // After computing the valid callables, walk any symbol uses to check 279 // for non-call references. We won't be able to track the lattice state 280 // for arguments to these callables, as we can't guarantee that we can see 281 // all of its calls. 282 Optional<SymbolTable::UseRange> uses = 283 SymbolTable::getSymbolUses(&symbolTableRegion); 284 if (!uses) { 285 // If we couldn't gather the symbol uses, conservatively assume that 286 // we can't track information for any nested symbols. 287 op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); }); 288 return; 289 } 290 291 for (const SymbolTable::SymbolUse &use : *uses) { 292 // If the use is a call, track it to avoid the need to recompute the 293 // reference later. 294 if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) { 295 Operation *symCallable = callOp.resolveCallable(&symbolTable); 296 auto callableLatticeIt = callableLatticeState.find(symCallable); 297 if (callableLatticeIt != callableLatticeState.end()) { 298 callToSymbolCallable.try_emplace(callOp, symCallable); 299 300 // We only need to record the call in the lattice if it produces any 301 // values. 302 if (callOp->getNumResults()) 303 callableLatticeIt->second.addSymbolCall(callOp); 304 } 305 continue; 306 } 307 // This use isn't a call, so don't we know all of the callers. 308 auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef()); 309 auto it = callableLatticeState.find(symbol); 310 if (it != callableLatticeState.end()) { 311 for (Region ®ion : it->first->getRegions()) 312 markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); 313 } 314 } 315 }; 316 SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), 317 walkFn); 318 } 319 320 void ForwardDataFlowSolver::visitOperation(Operation *op) { 321 // Collect all of the lattice elements feeding into this operation. If any are 322 // not yet resolved, bail out and wait for them to resolve. 323 SmallVector<AbstractLatticeElement *, 8> operandLattices; 324 operandLattices.reserve(op->getNumOperands()); 325 for (Value operand : op->getOperands()) { 326 AbstractLatticeElement *operandLattice = 327 analysis.lookupLatticeElement(operand); 328 if (!operandLattice || operandLattice->isUninitialized()) 329 return; 330 operandLattices.push_back(operandLattice); 331 } 332 333 // If this is a terminator operation, process any control flow lattice state. 334 if (op->hasTrait<OpTrait::IsTerminator>()) 335 visitTerminatorOperation(op, operandLattices); 336 337 // Process call operations. The call visitor processes result values, so we 338 // can exit afterwards. 339 if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) 340 return visitCallOperation(call); 341 342 // Process callable operations. These are specially handled region operations 343 // that track dataflow via calls. 344 if (isa<CallableOpInterface>(op)) { 345 // If this callable has a tracked lattice state, it will be visited by calls 346 // that reference it instead. This way, we don't assume that it is 347 // executable unless there is a proper reference to it. 348 if (callableLatticeState.count(op)) 349 return; 350 return visitCallableOperation(op); 351 } 352 353 // Process region holding operations. 354 if (op->getNumRegions()) { 355 // Check to see if we can reason about the internal control flow of this 356 // region operation. 357 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) 358 return visitRegionBranchOperation(branch, operandLattices); 359 360 // If we can't, conservatively mark all regions as executable. 361 // TODO: Let the `visitOperation` method decide how to propagate 362 // information to the block arguments. 363 for (Region ®ion : op->getRegions()) 364 markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); 365 } 366 367 // If this op produces no results, it can't produce any constants. 368 if (op->getNumResults() == 0) 369 return; 370 371 // If all of the results of this operation are already resolved, bail out 372 // early. 373 auto isAtFixpointFn = [&](Value value) { return isAtFixpoint(value); }; 374 if (llvm::all_of(op->getResults(), isAtFixpointFn)) 375 return; 376 377 // Visit the current operation. 378 if (analysis.visitOperation(op, operandLattices) == ChangeResult::Change) 379 opWorklist.push(op); 380 381 // `visitOperation` is required to define all of the result lattices. 382 assert(llvm::none_of( 383 op->getResults(), 384 [&](Value value) { 385 return analysis.getLatticeElement(value).isUninitialized(); 386 }) && 387 "expected `visitOperation` to define all result lattices"); 388 } 389 390 void ForwardDataFlowSolver::visitCallableOperation(Operation *op) { 391 // Mark the regions as executable. If we aren't tracking lattice state for 392 // this callable, mark all of the region arguments as having reached a 393 // fixpoint. 394 bool isTrackingLatticeState = callableLatticeState.count(op); 395 for (Region ®ion : op->getRegions()) 396 markEntryBlockExecutable(®ion, !isTrackingLatticeState); 397 398 // TODO: Add support for non-symbol callables when necessary. If the callable 399 // has non-call uses we would mark as having reached pessimistic fixpoint, 400 // otherwise allow for propagating the return values out. 401 markAllPessimisticFixpoint(op, op->getResults()); 402 } 403 404 void ForwardDataFlowSolver::visitCallOperation(CallOpInterface op) { 405 ResultRange callResults = op->getResults(); 406 407 // Resolve the callable operation for this call. 408 Operation *callableOp = nullptr; 409 if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>()) 410 callableOp = callableValue.getDefiningOp(); 411 else 412 callableOp = callToSymbolCallable.lookup(op); 413 414 // The callable of this call can't be resolved, mark any results overdefined. 415 if (!callableOp) 416 return markAllPessimisticFixpoint(op, callResults); 417 418 // If this callable is tracking state, merge the argument operands with the 419 // arguments of the callable. 420 auto callableLatticeIt = callableLatticeState.find(callableOp); 421 if (callableLatticeIt == callableLatticeState.end()) 422 return markAllPessimisticFixpoint(op, callResults); 423 424 OperandRange callOperands = op.getArgOperands(); 425 auto callableArgs = callableLatticeIt->second.getCallableArguments(); 426 for (auto it : llvm::zip(callOperands, callableArgs)) { 427 BlockArgument callableArg = std::get<1>(it); 428 AbstractLatticeElement &argValue = analysis.getLatticeElement(callableArg); 429 AbstractLatticeElement &operandValue = 430 analysis.getLatticeElement(std::get<0>(it)); 431 if (argValue.join(operandValue) == ChangeResult::Change) 432 visitUsers(callableArg); 433 } 434 435 // Visit the callable. 436 visitCallableOperation(callableOp); 437 438 // Merge in the lattice state for the callable results as well. 439 auto callableResults = callableLatticeIt->second.getResultLatticeElements(); 440 for (auto it : llvm::zip(callResults, callableResults)) 441 join(/*owner=*/op, 442 /*to=*/analysis.getLatticeElement(std::get<0>(it)), 443 /*from=*/std::get<1>(it)); 444 } 445 446 void ForwardDataFlowSolver::visitRegionBranchOperation( 447 RegionBranchOpInterface branch, 448 ArrayRef<AbstractLatticeElement *> operandLattices) { 449 // Check to see which regions are executable. 450 SmallVector<RegionSuccessor, 1> successors; 451 analysis.getSuccessorsForOperands(branch, /*sourceIndex=*/llvm::None, 452 operandLattices, successors); 453 454 // If the interface identified that no region will be executed. Mark 455 // any results of this operation as overdefined, as we can't reason about 456 // them. 457 // TODO: If we had an interface to detect pass through operands, we could 458 // resolve some results based on the lattice state of the operands. We could 459 // also allow for the parent operation to have itself as a region successor. 460 if (successors.empty()) 461 return markAllPessimisticFixpoint(branch, branch->getResults()); 462 return visitRegionSuccessors( 463 branch, successors, [&](Optional<unsigned> index) { 464 assert(index && "expected valid region index"); 465 return branch.getSuccessorEntryOperands(*index); 466 }); 467 } 468 469 void ForwardDataFlowSolver::visitRegionSuccessors( 470 Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors, 471 function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) { 472 for (const RegionSuccessor &it : regionSuccessors) { 473 Region *region = it.getSuccessor(); 474 ValueRange succArgs = it.getSuccessorInputs(); 475 476 // Check to see if this is the parent operation. 477 if (!region) { 478 ResultRange results = parentOp->getResults(); 479 if (llvm::all_of(results, [&](Value res) { return isAtFixpoint(res); })) 480 continue; 481 482 // Mark the results outside of the input range as having reached the 483 // pessimistic fixpoint. 484 // TODO: This isn't exactly ideal. There may be situations in which a 485 // region operation can provide information for certain results that 486 // aren't part of the control flow. 487 if (succArgs.size() != results.size()) { 488 opWorklist.push(parentOp); 489 if (succArgs.empty()) { 490 markAllPessimisticFixpoint(results); 491 continue; 492 } 493 494 unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber(); 495 markAllPessimisticFixpoint(results.take_front(firstResIdx)); 496 markAllPessimisticFixpoint( 497 results.drop_front(firstResIdx + succArgs.size())); 498 } 499 500 // Update the lattice for any operation results. 501 OperandRange operands = getInputsForRegion(/*index=*/llvm::None); 502 for (auto it : llvm::zip(succArgs, operands)) 503 join(parentOp, analysis.getLatticeElement(std::get<0>(it)), 504 analysis.getLatticeElement(std::get<1>(it))); 505 continue; 506 } 507 assert(!region->empty() && "expected region to be non-empty"); 508 Block *entryBlock = ®ion->front(); 509 markBlockExecutable(entryBlock); 510 511 // If all of the arguments have already reached a fixpoint, the arguments 512 // have already been fully resolved. 513 Block::BlockArgListType arguments = entryBlock->getArguments(); 514 if (llvm::all_of(arguments, [&](Value arg) { return isAtFixpoint(arg); })) 515 continue; 516 517 // Mark any arguments that do not receive inputs as having reached a 518 // pessimistic fixpoint, we won't be able to discern if they are constant. 519 // TODO: This isn't exactly ideal. There may be situations in which a 520 // region operation can provide information for certain results that 521 // aren't part of the control flow. 522 if (succArgs.size() != arguments.size()) { 523 if (succArgs.empty()) { 524 markAllPessimisticFixpoint(arguments); 525 continue; 526 } 527 528 unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber(); 529 markAllPessimisticFixpointAndVisitUsers( 530 arguments.take_front(firstArgIdx)); 531 markAllPessimisticFixpointAndVisitUsers( 532 arguments.drop_front(firstArgIdx + succArgs.size())); 533 } 534 535 // Update the lattice of arguments that have inputs from the predecessor. 536 OperandRange succOperands = getInputsForRegion(region->getRegionNumber()); 537 for (auto it : llvm::zip(succArgs, succOperands)) { 538 AbstractLatticeElement &argValue = 539 analysis.getLatticeElement(std::get<0>(it)); 540 AbstractLatticeElement &operandValue = 541 analysis.getLatticeElement(std::get<1>(it)); 542 if (argValue.join(operandValue) == ChangeResult::Change) 543 visitUsers(std::get<0>(it)); 544 } 545 } 546 } 547 548 void ForwardDataFlowSolver::visitTerminatorOperation( 549 Operation *op, ArrayRef<AbstractLatticeElement *> operandLattices) { 550 // If this operation has no successors, we treat it as an exiting terminator. 551 if (op->getNumSuccessors() == 0) { 552 Region *parentRegion = op->getParentRegion(); 553 Operation *parentOp = parentRegion->getParentOp(); 554 555 // Check to see if this is a terminator for a callable region. 556 if (isa<CallableOpInterface>(parentOp)) 557 return visitCallableTerminatorOperation(parentOp, op, operandLattices); 558 559 // Otherwise, check to see if the parent tracks region control flow. 560 auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp); 561 if (!regionInterface || !isBlockExecutable(parentOp->getBlock())) 562 return; 563 564 // Query the set of successors of the current region using the current 565 // optimistic lattice state. 566 SmallVector<RegionSuccessor, 1> regionSuccessors; 567 analysis.getSuccessorsForOperands(regionInterface, 568 parentRegion->getRegionNumber(), 569 operandLattices, regionSuccessors); 570 if (regionSuccessors.empty()) 571 return; 572 573 // Try to get "region-like" successor operands if possible in order to 574 // propagate the operand states to the successors. 575 if (isRegionReturnLike(op)) { 576 return visitRegionSuccessors( 577 parentOp, regionSuccessors, [&](Optional<unsigned> regionIndex) { 578 // Determine the individual region successor operands for the given 579 // region index (if any). 580 return *getRegionBranchSuccessorOperands(op, regionIndex); 581 }); 582 } 583 584 // If this terminator is not "region-like", conservatively mark all of the 585 // successor values as having reached the pessimistic fixpoint. 586 for (auto &it : regionSuccessors) { 587 // If the successor is a region, mark the entry block as executable so 588 // that we visit operations defined within. If the successor is the 589 // parent operation, we simply mark the control flow results as having 590 // reached the pessimistic state. 591 if (Region *region = it.getSuccessor()) 592 markEntryBlockExecutable(region, /*markPessimisticFixpoint=*/true); 593 else 594 markAllPessimisticFixpointAndVisitUsers(it.getSuccessorInputs()); 595 } 596 } 597 598 // Try to resolve to a specific set of successors with the current optimistic 599 // lattice state. 600 Block *block = op->getBlock(); 601 if (auto branch = dyn_cast<BranchOpInterface>(op)) { 602 SmallVector<Block *> successors; 603 if (succeeded(analysis.getSuccessorsForOperands(branch, operandLattices, 604 successors))) { 605 for (Block *succ : successors) 606 markEdgeExecutable(block, succ); 607 return; 608 } 609 } 610 611 // Otherwise, conservatively treat all edges as executable. 612 for (Block *succ : op->getSuccessors()) 613 markEdgeExecutable(block, succ); 614 } 615 616 void ForwardDataFlowSolver::visitCallableTerminatorOperation( 617 Operation *callable, Operation *terminator, 618 ArrayRef<AbstractLatticeElement *> operandLattices) { 619 // If there are no exiting values, we have nothing to track. 620 if (terminator->getNumOperands() == 0) 621 return; 622 623 // If this callable isn't tracking any lattice state there is nothing to do. 624 auto latticeIt = callableLatticeState.find(callable); 625 if (latticeIt == callableLatticeState.end()) 626 return; 627 assert(callable->getNumResults() == 0 && "expected symbol callable"); 628 629 // If this terminator is not "return-like", conservatively mark all of the 630 // call-site results as having reached the pessimistic fixpoint. 631 auto callableResultLattices = latticeIt->second.getResultLatticeElements(); 632 if (!terminator->hasTrait<OpTrait::ReturnLike>()) { 633 for (auto &it : callableResultLattices) 634 it.markPessimisticFixpoint(); 635 for (Operation *call : latticeIt->second.getSymbolCalls()) 636 markAllPessimisticFixpoint(call, call->getResults()); 637 return; 638 } 639 640 // Merge the lattice state for terminator operands into the results. 641 ChangeResult result = ChangeResult::NoChange; 642 for (auto it : llvm::zip(operandLattices, callableResultLattices)) 643 result |= std::get<1>(it).join(*std::get<0>(it)); 644 if (result == ChangeResult::NoChange) 645 return; 646 647 // If any of the result lattices changed, update the callers. 648 for (Operation *call : latticeIt->second.getSymbolCalls()) 649 for (auto it : llvm::zip(call->getResults(), callableResultLattices)) 650 join(call, analysis.getLatticeElement(std::get<0>(it)), std::get<1>(it)); 651 } 652 653 void ForwardDataFlowSolver::visitBlock(Block *block) { 654 // If the block is not the entry block we need to compute the lattice state 655 // for the block arguments. Entry block argument lattices are computed 656 // elsewhere, such as when visiting the parent operation. 657 if (!block->isEntryBlock()) { 658 for (int i : llvm::seq<int>(0, block->getNumArguments())) 659 visitBlockArgument(block, i); 660 } 661 662 // Visit all of the operations within the block. 663 for (Operation &op : *block) 664 visitOperation(&op); 665 } 666 667 void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) { 668 BlockArgument arg = block->getArgument(i); 669 AbstractLatticeElement &argLattice = analysis.getLatticeElement(arg); 670 if (argLattice.isAtFixpoint()) 671 return; 672 673 ChangeResult updatedLattice = ChangeResult::NoChange; 674 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { 675 Block *pred = *it; 676 677 // We only care about this predecessor if it is going to execute. 678 if (!isEdgeExecutable(pred, block)) 679 continue; 680 681 // Try to get the operand forwarded by the predecessor. If we can't reason 682 // about the terminator of the predecessor, mark as having reached a 683 // fixpoint. 684 auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator()); 685 if (!branch) { 686 updatedLattice |= argLattice.markPessimisticFixpoint(); 687 break; 688 } 689 Value operand = branch.getSuccessorOperands(it.getSuccessorIndex())[i]; 690 if (!operand) { 691 updatedLattice |= argLattice.markPessimisticFixpoint(); 692 break; 693 } 694 695 // If the operand hasn't been resolved, it is uninitialized and can merge 696 // with anything. 697 AbstractLatticeElement *operandLattice = 698 analysis.lookupLatticeElement(operand); 699 if (!operandLattice) 700 continue; 701 702 // Otherwise, join the operand lattice into the argument lattice. 703 updatedLattice |= argLattice.join(*operandLattice); 704 if (argLattice.isAtFixpoint()) 705 break; 706 } 707 708 // If the lattice changed, visit users of the argument. 709 if (updatedLattice == ChangeResult::Change) 710 visitUsers(arg); 711 } 712 713 ChangeResult 714 ForwardDataFlowSolver::markEntryBlockExecutable(Region *region, 715 bool markPessimisticFixpoint) { 716 if (!region->empty()) { 717 if (markPessimisticFixpoint) 718 markAllPessimisticFixpoint(region->front().getArguments()); 719 return markBlockExecutable(®ion->front()); 720 } 721 return ChangeResult::NoChange; 722 } 723 724 ChangeResult ForwardDataFlowSolver::markBlockExecutable(Block *block) { 725 bool marked = executableBlocks.insert(block).second; 726 if (marked) 727 blockWorklist.push(block); 728 return marked ? ChangeResult::Change : ChangeResult::NoChange; 729 } 730 731 bool ForwardDataFlowSolver::isBlockExecutable(Block *block) const { 732 return executableBlocks.count(block); 733 } 734 735 void ForwardDataFlowSolver::markEdgeExecutable(Block *from, Block *to) { 736 executableEdges.insert(std::make_pair(from, to)); 737 738 // Mark the destination as executable, and reprocess its arguments if it was 739 // already executable. 740 if (markBlockExecutable(to) == ChangeResult::NoChange) { 741 for (int i : llvm::seq<int>(0, to->getNumArguments())) 742 visitBlockArgument(to, i); 743 } 744 } 745 746 bool ForwardDataFlowSolver::isEdgeExecutable(Block *from, Block *to) const { 747 return executableEdges.count(std::make_pair(from, to)); 748 } 749 750 void ForwardDataFlowSolver::markPessimisticFixpoint(Value value) { 751 analysis.getLatticeElement(value).markPessimisticFixpoint(); 752 } 753 754 bool ForwardDataFlowSolver::isAtFixpoint(Value value) const { 755 if (auto *lattice = analysis.lookupLatticeElement(value)) 756 return lattice->isAtFixpoint(); 757 return false; 758 } 759 760 void ForwardDataFlowSolver::join(Operation *owner, AbstractLatticeElement &to, 761 const AbstractLatticeElement &from) { 762 if (to.join(from) == ChangeResult::Change) 763 opWorklist.push(owner); 764 } 765 766 //===----------------------------------------------------------------------===// 767 // AbstractLatticeElement 768 //===----------------------------------------------------------------------===// 769 770 AbstractLatticeElement::~AbstractLatticeElement() = default; 771 772 //===----------------------------------------------------------------------===// 773 // ForwardDataFlowAnalysisBase 774 //===----------------------------------------------------------------------===// 775 776 ForwardDataFlowAnalysisBase::~ForwardDataFlowAnalysisBase() = default; 777 778 AbstractLatticeElement & 779 ForwardDataFlowAnalysisBase::getLatticeElement(Value value) { 780 AbstractLatticeElement *&latticeValue = latticeValues[value]; 781 if (!latticeValue) 782 latticeValue = createLatticeElement(value); 783 return *latticeValue; 784 } 785 786 AbstractLatticeElement * 787 ForwardDataFlowAnalysisBase::lookupLatticeElement(Value value) { 788 return latticeValues.lookup(value); 789 } 790 791 void ForwardDataFlowAnalysisBase::run(Operation *topLevelOp) { 792 // Run the main dataflow solver. 793 ForwardDataFlowSolver solver(*this, topLevelOp); 794 solver.solve(); 795 796 // Any values that are still uninitialized now go to a pessimistic fixpoint, 797 // otherwise we assume an optimistic fixpoint has been reached. 798 for (auto &it : latticeValues) 799 if (it.second->isUninitialized()) 800 it.second->markPessimisticFixpoint(); 801 else 802 it.second->markOptimisticFixpoint(); 803 } 804