1 //===- DataFlowAnalysis.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataFlowAnalysis.h" 10 #include "mlir/IR/Operation.h" 11 #include "mlir/Interfaces/CallInterfaces.h" 12 #include "mlir/Interfaces/ControlFlowInterfaces.h" 13 #include "llvm/ADT/SmallPtrSet.h" 14 15 using namespace mlir; 16 using namespace mlir::detail; 17 18 namespace { 19 /// This class contains various state used when computing the lattice elements 20 /// of a callable operation. 21 class CallableLatticeState { 22 public: 23 /// Build a lattice state with a given callable region, and a specified number 24 /// of results to be initialized to the default lattice element. 25 CallableLatticeState(ForwardDataFlowAnalysisBase &analysis, 26 Region *callableRegion, unsigned numResults) 27 : callableArguments(callableRegion->getArguments()), 28 resultLatticeElements(numResults) { 29 for (AbstractLatticeElement *&it : resultLatticeElements) 30 it = analysis.createLatticeElement(); 31 } 32 33 /// Returns the arguments to the callable region. 34 Block::BlockArgListType getCallableArguments() const { 35 return callableArguments; 36 } 37 38 /// Returns the lattice element for the results of the callable region. 39 auto getResultLatticeElements() { 40 return llvm::make_pointee_range(resultLatticeElements); 41 } 42 43 /// Add a call to this callable. This is only used if the callable defines a 44 /// symbol. 45 void addSymbolCall(Operation *op) { symbolCalls.push_back(op); } 46 47 /// Return the calls that reference this callable. This is only used 48 /// if the callable defines a symbol. 49 ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; } 50 51 private: 52 /// The arguments of the callable region. 53 Block::BlockArgListType callableArguments; 54 55 /// The lattice state for each of the results of this region. The return 56 /// values of the callable aren't SSA values, so we need to track them 57 /// separately. 58 SmallVector<AbstractLatticeElement *, 4> resultLatticeElements; 59 60 /// The calls referencing this callable if this callable defines a symbol. 61 /// This removes the need to recompute symbol references during propagation. 62 /// Value based references are trivial to resolve, so they can be done 63 /// in-place. 64 SmallVector<Operation *, 4> symbolCalls; 65 }; 66 67 /// This class represents the solver for a forward dataflow analysis. This class 68 /// acts as the propagation engine for computing which lattice elements. 69 class ForwardDataFlowSolver { 70 public: 71 /// Initialize the solver with the given top-level operation. 72 ForwardDataFlowSolver(ForwardDataFlowAnalysisBase &analysis, Operation *op); 73 74 /// Run the solver until it converges. 75 void solve(); 76 77 private: 78 /// Initialize the set of symbol defining callables that can have their 79 /// arguments and results tracked. 'op' is the top-level operation that the 80 /// solver is operating on. 81 void initializeSymbolCallables(Operation *op); 82 83 /// Visit the users of the given IR that reside within executable blocks. 84 template <typename T> 85 void visitUsers(T &value) { 86 for (Operation *user : value.getUsers()) 87 if (isBlockExecutable(user->getBlock())) 88 visitOperation(user); 89 } 90 91 /// Visit the given operation and compute any necessary lattice state. 92 void visitOperation(Operation *op); 93 94 /// Visit the given call operation and compute any necessary lattice state. 95 void visitCallOperation(CallOpInterface op); 96 97 /// Visit the given callable operation and compute any necessary lattice 98 /// state. 99 void visitCallableOperation(Operation *op); 100 101 /// Visit the given region branch operation, which defines regions, and 102 /// compute any necessary lattice state. This also resolves the lattice state 103 /// of both the operation results and any nested regions. 104 void visitRegionBranchOperation( 105 RegionBranchOpInterface branch, 106 ArrayRef<AbstractLatticeElement *> operandLattices); 107 108 /// Visit the given set of region successors, computing any necessary lattice 109 /// state. The provided function returns the input operands to the region at 110 /// the given index. If the index is 'None', the input operands correspond to 111 /// the parent operation results. 112 void visitRegionSuccessors( 113 Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors, 114 function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion); 115 116 /// Visit the given terminator operation and compute any necessary lattice 117 /// state. 118 void 119 visitTerminatorOperation(Operation *op, 120 ArrayRef<AbstractLatticeElement *> operandLattices); 121 122 /// Visit the given terminator operation that exits a callable region. These 123 /// are terminators with no CFG successors. 124 void visitCallableTerminatorOperation( 125 Operation *callable, Operation *terminator, 126 ArrayRef<AbstractLatticeElement *> operandLattices); 127 128 /// Visit the given block and compute any necessary lattice state. 129 void visitBlock(Block *block); 130 131 /// Visit argument #'i' of the given block and compute any necessary lattice 132 /// state. 133 void visitBlockArgument(Block *block, int i); 134 135 /// Mark the entry block of the given region as executable. Returns NoChange 136 /// if the block was already marked executable. If `markPessimisticFixpoint` 137 /// is true, the arguments of the entry block are also marked as having 138 /// reached the pessimistic fixpoint. 139 ChangeResult markEntryBlockExecutable(Region *region, 140 bool markPessimisticFixpoint); 141 142 /// Mark the given block as executable. Returns NoChange if the block was 143 /// already marked executable. 144 ChangeResult markBlockExecutable(Block *block); 145 146 /// Returns true if the given block is executable. 147 bool isBlockExecutable(Block *block) const; 148 149 /// Mark the edge between 'from' and 'to' as executable. 150 void markEdgeExecutable(Block *from, Block *to); 151 152 /// Return true if the edge between 'from' and 'to' is executable. 153 bool isEdgeExecutable(Block *from, Block *to) const; 154 155 /// Mark the given value as having reached the pessimistic fixpoint. This 156 /// means that we cannot further refine the state of this value. 157 void markPessimisticFixpoint(Value value); 158 159 /// Mark all of the given values as having reaching the pessimistic fixpoint. 160 template <typename ValuesT> 161 void markAllPessimisticFixpoint(ValuesT values) { 162 for (auto value : values) 163 markPessimisticFixpoint(value); 164 } 165 template <typename ValuesT> 166 void markAllPessimisticFixpoint(Operation *op, ValuesT values) { 167 markAllPessimisticFixpoint(values); 168 opWorklist.push_back(op); 169 } 170 template <typename ValuesT> 171 void markAllPessimisticFixpointAndVisitUsers(ValuesT values) { 172 for (auto value : values) { 173 AbstractLatticeElement &lattice = analysis.getLatticeElement(value); 174 if (lattice.markPessimisticFixpoint() == ChangeResult::Change) 175 visitUsers(value); 176 } 177 } 178 179 /// Returns true if the given value was marked as having reached the 180 /// pessimistic fixpoint. 181 bool isAtFixpoint(Value value) const; 182 183 /// Merge in the given lattice 'from' into the lattice 'to'. 'owner' 184 /// corresponds to the parent operation of the lattice for 'to'. 185 void join(Operation *owner, AbstractLatticeElement &to, 186 const AbstractLatticeElement &from); 187 188 /// A reference to the dataflow analysis being computed. 189 ForwardDataFlowAnalysisBase &analysis; 190 191 /// The set of blocks that are known to execute, or are intrinsically live. 192 SmallPtrSet<Block *, 16> executableBlocks; 193 194 /// The set of control flow edges that are known to execute. 195 DenseSet<std::pair<Block *, Block *>> executableEdges; 196 197 /// A worklist containing blocks that need to be processed. 198 SmallVector<Block *, 64> blockWorklist; 199 200 /// A worklist of operations that need to be processed. 201 SmallVector<Operation *, 64> opWorklist; 202 203 /// The callable operations that have their argument/result state tracked. 204 DenseMap<Operation *, CallableLatticeState> callableLatticeState; 205 206 /// A map between a call operation and the resolved symbol callable. This 207 /// avoids re-resolving symbol references during propagation. Value based 208 /// callables are trivial to resolve, so they can be done in-place. 209 DenseMap<Operation *, Operation *> callToSymbolCallable; 210 211 /// A symbol table used for O(1) symbol lookups during simplification. 212 SymbolTableCollection symbolTable; 213 }; 214 } // namespace 215 216 ForwardDataFlowSolver::ForwardDataFlowSolver( 217 ForwardDataFlowAnalysisBase &analysis, Operation *op) 218 : analysis(analysis) { 219 /// Initialize the solver with the regions within this operation. 220 for (Region ®ion : op->getRegions()) { 221 // Mark the entry block as executable. The values passed to these regions 222 // are also invisible, so mark any arguments as reaching the pessimistic 223 // fixpoint. 224 markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); 225 } 226 initializeSymbolCallables(op); 227 } 228 229 void ForwardDataFlowSolver::solve() { 230 while (!blockWorklist.empty() || !opWorklist.empty()) { 231 // Process any operations in the op worklist. 232 while (!opWorklist.empty()) 233 visitUsers(*opWorklist.pop_back_val()); 234 235 // Process any blocks in the block worklist. 236 while (!blockWorklist.empty()) 237 visitBlock(blockWorklist.pop_back_val()); 238 } 239 } 240 241 void ForwardDataFlowSolver::initializeSymbolCallables(Operation *op) { 242 // Initialize the set of symbol callables that can have their state tracked. 243 // This tracks which symbol callable operations we can propagate within and 244 // out of. 245 auto walkFn = [&](Operation *symTable, bool allUsesVisible) { 246 Region &symbolTableRegion = symTable->getRegion(0); 247 Block *symbolTableBlock = &symbolTableRegion.front(); 248 for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) { 249 // We won't be able to track external callables. 250 Region *callableRegion = callable.getCallableRegion(); 251 if (!callableRegion) 252 continue; 253 // We only care about symbol defining callables here. 254 auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation()); 255 if (!symbol) 256 continue; 257 callableLatticeState.try_emplace(callable, analysis, callableRegion, 258 callable.getCallableResults().size()); 259 260 // If not all of the uses of this symbol are visible, we can't track the 261 // state of the arguments. 262 if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) { 263 for (Region ®ion : callable->getRegions()) 264 markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); 265 } 266 } 267 if (callableLatticeState.empty()) 268 return; 269 270 // After computing the valid callables, walk any symbol uses to check 271 // for non-call references. We won't be able to track the lattice state 272 // for arguments to these callables, as we can't guarantee that we can see 273 // all of its calls. 274 Optional<SymbolTable::UseRange> uses = 275 SymbolTable::getSymbolUses(&symbolTableRegion); 276 if (!uses) { 277 // If we couldn't gather the symbol uses, conservatively assume that 278 // we can't track information for any nested symbols. 279 op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); }); 280 return; 281 } 282 283 for (const SymbolTable::SymbolUse &use : *uses) { 284 // If the use is a call, track it to avoid the need to recompute the 285 // reference later. 286 if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) { 287 Operation *symCallable = callOp.resolveCallable(&symbolTable); 288 auto callableLatticeIt = callableLatticeState.find(symCallable); 289 if (callableLatticeIt != callableLatticeState.end()) { 290 callToSymbolCallable.try_emplace(callOp, symCallable); 291 292 // We only need to record the call in the lattice if it produces any 293 // values. 294 if (callOp->getNumResults()) 295 callableLatticeIt->second.addSymbolCall(callOp); 296 } 297 continue; 298 } 299 // This use isn't a call, so don't we know all of the callers. 300 auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef()); 301 auto it = callableLatticeState.find(symbol); 302 if (it != callableLatticeState.end()) { 303 for (Region ®ion : it->first->getRegions()) 304 markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); 305 } 306 } 307 }; 308 SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), 309 walkFn); 310 } 311 312 void ForwardDataFlowSolver::visitOperation(Operation *op) { 313 // Collect all of the lattice elements feeding into this operation. If any are 314 // not yet resolved, bail out and wait for them to resolve. 315 SmallVector<AbstractLatticeElement *, 8> operandLattices; 316 operandLattices.reserve(op->getNumOperands()); 317 for (Value operand : op->getOperands()) { 318 AbstractLatticeElement *operandLattice = 319 analysis.lookupLatticeElement(operand); 320 if (!operandLattice || operandLattice->isUninitialized()) 321 return; 322 operandLattices.push_back(operandLattice); 323 } 324 325 // If this is a terminator operation, process any control flow lattice state. 326 if (op->hasTrait<OpTrait::IsTerminator>()) 327 visitTerminatorOperation(op, operandLattices); 328 329 // Process call operations. The call visitor processes result values, so we 330 // can exit afterwards. 331 if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) 332 return visitCallOperation(call); 333 334 // Process callable operations. These are specially handled region operations 335 // that track dataflow via calls. 336 if (isa<CallableOpInterface>(op)) { 337 // If this callable has a tracked lattice state, it will be visited by calls 338 // that reference it instead. This way, we don't assume that it is 339 // executable unless there is a proper reference to it. 340 if (callableLatticeState.count(op)) 341 return; 342 return visitCallableOperation(op); 343 } 344 345 // Process region holding operations. 346 if (op->getNumRegions()) { 347 // Check to see if we can reason about the internal control flow of this 348 // region operation. 349 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) 350 return visitRegionBranchOperation(branch, operandLattices); 351 352 // If we can't, conservatively mark all regions as executable. 353 // TODO: Let the `visitOperation` method decide how to propagate 354 // information to the block arguments. 355 for (Region ®ion : op->getRegions()) 356 markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); 357 } 358 359 // If this op produces no results, it can't produce any constants. 360 if (op->getNumResults() == 0) 361 return; 362 363 // If all of the results of this operation are already resolved, bail out 364 // early. 365 auto isAtFixpointFn = [&](Value value) { return isAtFixpoint(value); }; 366 if (llvm::all_of(op->getResults(), isAtFixpointFn)) 367 return; 368 369 // Visit the current operation. 370 if (analysis.visitOperation(op, operandLattices) == ChangeResult::Change) 371 opWorklist.push_back(op); 372 373 // `visitOperation` is required to define all of the result lattices. 374 assert(llvm::none_of( 375 op->getResults(), 376 [&](Value value) { 377 return analysis.getLatticeElement(value).isUninitialized(); 378 }) && 379 "expected `visitOperation` to define all result lattices"); 380 } 381 382 void ForwardDataFlowSolver::visitCallableOperation(Operation *op) { 383 // Mark the regions as executable. If we aren't tracking lattice state for 384 // this callable, mark all of the region arguments as having reached a 385 // fixpoint. 386 bool isTrackingLatticeState = callableLatticeState.count(op); 387 for (Region ®ion : op->getRegions()) 388 markEntryBlockExecutable(®ion, !isTrackingLatticeState); 389 390 // TODO: Add support for non-symbol callables when necessary. If the callable 391 // has non-call uses we would mark as having reached pessimistic fixpoint, 392 // otherwise allow for propagating the return values out. 393 markAllPessimisticFixpoint(op, op->getResults()); 394 } 395 396 void ForwardDataFlowSolver::visitCallOperation(CallOpInterface op) { 397 ResultRange callResults = op->getResults(); 398 399 // Resolve the callable operation for this call. 400 Operation *callableOp = nullptr; 401 if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>()) 402 callableOp = callableValue.getDefiningOp(); 403 else 404 callableOp = callToSymbolCallable.lookup(op); 405 406 // The callable of this call can't be resolved, mark any results overdefined. 407 if (!callableOp) 408 return markAllPessimisticFixpoint(op, callResults); 409 410 // If this callable is tracking state, merge the argument operands with the 411 // arguments of the callable. 412 auto callableLatticeIt = callableLatticeState.find(callableOp); 413 if (callableLatticeIt == callableLatticeState.end()) 414 return markAllPessimisticFixpoint(op, callResults); 415 416 OperandRange callOperands = op.getArgOperands(); 417 auto callableArgs = callableLatticeIt->second.getCallableArguments(); 418 for (auto it : llvm::zip(callOperands, callableArgs)) { 419 BlockArgument callableArg = std::get<1>(it); 420 AbstractLatticeElement &argValue = analysis.getLatticeElement(callableArg); 421 AbstractLatticeElement &operandValue = 422 analysis.getLatticeElement(std::get<0>(it)); 423 if (argValue.join(operandValue) == ChangeResult::Change) 424 visitUsers(callableArg); 425 } 426 427 // Visit the callable. 428 visitCallableOperation(callableOp); 429 430 // Merge in the lattice state for the callable results as well. 431 auto callableResults = callableLatticeIt->second.getResultLatticeElements(); 432 for (auto it : llvm::zip(callResults, callableResults)) 433 join(/*owner=*/op, 434 /*to=*/analysis.getLatticeElement(std::get<0>(it)), 435 /*from=*/std::get<1>(it)); 436 } 437 438 void ForwardDataFlowSolver::visitRegionBranchOperation( 439 RegionBranchOpInterface branch, 440 ArrayRef<AbstractLatticeElement *> operandLattices) { 441 // Check to see which regions are executable. 442 SmallVector<RegionSuccessor, 1> successors; 443 analysis.getSuccessorsForOperands(branch, /*sourceIndex=*/llvm::None, 444 operandLattices, successors); 445 446 // If the interface identified that no region will be executed. Mark 447 // any results of this operation as overdefined, as we can't reason about 448 // them. 449 // TODO: If we had an interface to detect pass through operands, we could 450 // resolve some results based on the lattice state of the operands. We could 451 // also allow for the parent operation to have itself as a region successor. 452 if (successors.empty()) 453 return markAllPessimisticFixpoint(branch, branch->getResults()); 454 return visitRegionSuccessors( 455 branch, successors, [&](Optional<unsigned> index) { 456 assert(index && "expected valid region index"); 457 return branch.getSuccessorEntryOperands(*index); 458 }); 459 } 460 461 void ForwardDataFlowSolver::visitRegionSuccessors( 462 Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors, 463 function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) { 464 for (const RegionSuccessor &it : regionSuccessors) { 465 Region *region = it.getSuccessor(); 466 ValueRange succArgs = it.getSuccessorInputs(); 467 468 // Check to see if this is the parent operation. 469 if (!region) { 470 ResultRange results = parentOp->getResults(); 471 if (llvm::all_of(results, [&](Value res) { return isAtFixpoint(res); })) 472 continue; 473 474 // Mark the results outside of the input range as having reached the 475 // pessimistic fixpoint. 476 // TODO: This isn't exactly ideal. There may be situations in which a 477 // region operation can provide information for certain results that 478 // aren't part of the control flow. 479 if (succArgs.size() != results.size()) { 480 opWorklist.push_back(parentOp); 481 if (succArgs.empty()) { 482 markAllPessimisticFixpoint(results); 483 continue; 484 } 485 486 unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber(); 487 markAllPessimisticFixpoint(results.take_front(firstResIdx)); 488 markAllPessimisticFixpoint( 489 results.drop_front(firstResIdx + succArgs.size())); 490 } 491 492 // Update the lattice for any operation results. 493 OperandRange operands = getInputsForRegion(/*index=*/llvm::None); 494 for (auto it : llvm::zip(succArgs, operands)) 495 join(parentOp, analysis.getLatticeElement(std::get<0>(it)), 496 analysis.getLatticeElement(std::get<1>(it))); 497 continue; 498 } 499 assert(!region->empty() && "expected region to be non-empty"); 500 Block *entryBlock = ®ion->front(); 501 markBlockExecutable(entryBlock); 502 503 // If all of the arguments have already reached a fixpoint, the arguments 504 // have already been fully resolved. 505 Block::BlockArgListType arguments = entryBlock->getArguments(); 506 if (llvm::all_of(arguments, [&](Value arg) { return isAtFixpoint(arg); })) 507 continue; 508 509 // Mark any arguments that do not receive inputs as having reached a 510 // pessimistic fixpoint, we won't be able to discern if they are constant. 511 // TODO: This isn't exactly ideal. There may be situations in which a 512 // region operation can provide information for certain results that 513 // aren't part of the control flow. 514 if (succArgs.size() != arguments.size()) { 515 if (succArgs.empty()) { 516 markAllPessimisticFixpoint(arguments); 517 continue; 518 } 519 520 unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber(); 521 markAllPessimisticFixpointAndVisitUsers( 522 arguments.take_front(firstArgIdx)); 523 markAllPessimisticFixpointAndVisitUsers( 524 arguments.drop_front(firstArgIdx + succArgs.size())); 525 } 526 527 // Update the lattice of arguments that have inputs from the predecessor. 528 OperandRange succOperands = getInputsForRegion(region->getRegionNumber()); 529 for (auto it : llvm::zip(succArgs, succOperands)) { 530 AbstractLatticeElement &argValue = 531 analysis.getLatticeElement(std::get<0>(it)); 532 AbstractLatticeElement &operandValue = 533 analysis.getLatticeElement(std::get<1>(it)); 534 if (argValue.join(operandValue) == ChangeResult::Change) 535 visitUsers(std::get<0>(it)); 536 } 537 } 538 } 539 540 void ForwardDataFlowSolver::visitTerminatorOperation( 541 Operation *op, ArrayRef<AbstractLatticeElement *> operandLattices) { 542 // If this operation has no successors, we treat it as an exiting terminator. 543 if (op->getNumSuccessors() == 0) { 544 Region *parentRegion = op->getParentRegion(); 545 Operation *parentOp = parentRegion->getParentOp(); 546 547 // Check to see if this is a terminator for a callable region. 548 if (isa<CallableOpInterface>(parentOp)) 549 return visitCallableTerminatorOperation(parentOp, op, operandLattices); 550 551 // Otherwise, check to see if the parent tracks region control flow. 552 auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp); 553 if (!regionInterface || !isBlockExecutable(parentOp->getBlock())) 554 return; 555 556 // Query the set of successors of the current region using the current 557 // optimistic lattice state. 558 SmallVector<RegionSuccessor, 1> regionSuccessors; 559 analysis.getSuccessorsForOperands(regionInterface, 560 parentRegion->getRegionNumber(), 561 operandLattices, regionSuccessors); 562 if (regionSuccessors.empty()) 563 return; 564 565 // Try to get "region-like" successor operands if possible in order to 566 // propagate the operand states to the successors. 567 if (isRegionReturnLike(op)) { 568 return visitRegionSuccessors( 569 parentOp, regionSuccessors, [&](Optional<unsigned> regionIndex) { 570 // Determine the individual region successor operands for the given 571 // region index (if any). 572 return *getRegionBranchSuccessorOperands(op, regionIndex); 573 }); 574 } 575 576 // If this terminator is not "region-like", conservatively mark all of the 577 // successor values as having reached the pessimistic fixpoint. 578 for (auto &it : regionSuccessors) { 579 // If the successor is a region, mark the entry block as executable so 580 // that we visit operations defined within. If the successor is the 581 // parent operation, we simply mark the control flow results as having 582 // reached the pessimistic state. 583 if (Region *region = it.getSuccessor()) 584 markEntryBlockExecutable(region, /*markPessimisticFixpoint=*/true); 585 else 586 markAllPessimisticFixpointAndVisitUsers(it.getSuccessorInputs()); 587 } 588 } 589 590 // Try to resolve to a specific set of successors with the current optimistic 591 // lattice state. 592 Block *block = op->getBlock(); 593 if (auto branch = dyn_cast<BranchOpInterface>(op)) { 594 SmallVector<Block *> successors; 595 if (succeeded(analysis.getSuccessorsForOperands(branch, operandLattices, 596 successors))) { 597 for (Block *succ : successors) 598 markEdgeExecutable(block, succ); 599 return; 600 } 601 } 602 603 // Otherwise, conservatively treat all edges as executable. 604 for (Block *succ : op->getSuccessors()) 605 markEdgeExecutable(block, succ); 606 } 607 608 void ForwardDataFlowSolver::visitCallableTerminatorOperation( 609 Operation *callable, Operation *terminator, 610 ArrayRef<AbstractLatticeElement *> operandLattices) { 611 // If there are no exiting values, we have nothing to track. 612 if (terminator->getNumOperands() == 0) 613 return; 614 615 // If this callable isn't tracking any lattice state there is nothing to do. 616 auto latticeIt = callableLatticeState.find(callable); 617 if (latticeIt == callableLatticeState.end()) 618 return; 619 assert(callable->getNumResults() == 0 && "expected symbol callable"); 620 621 // If this terminator is not "return-like", conservatively mark all of the 622 // call-site results as having reached the pessimistic fixpoint. 623 auto callableResultLattices = latticeIt->second.getResultLatticeElements(); 624 if (!terminator->hasTrait<OpTrait::ReturnLike>()) { 625 for (auto &it : callableResultLattices) 626 it.markPessimisticFixpoint(); 627 for (Operation *call : latticeIt->second.getSymbolCalls()) 628 markAllPessimisticFixpoint(call, call->getResults()); 629 return; 630 } 631 632 // Merge the lattice state for terminator operands into the results. 633 ChangeResult result = ChangeResult::NoChange; 634 for (auto it : llvm::zip(operandLattices, callableResultLattices)) 635 result |= std::get<1>(it).join(*std::get<0>(it)); 636 if (result == ChangeResult::NoChange) 637 return; 638 639 // If any of the result lattices changed, update the callers. 640 for (Operation *call : latticeIt->second.getSymbolCalls()) 641 for (auto it : llvm::zip(call->getResults(), callableResultLattices)) 642 join(call, analysis.getLatticeElement(std::get<0>(it)), std::get<1>(it)); 643 } 644 645 void ForwardDataFlowSolver::visitBlock(Block *block) { 646 // If the block is not the entry block we need to compute the lattice state 647 // for the block arguments. Entry block argument lattices are computed 648 // elsewhere, such as when visiting the parent operation. 649 if (!block->isEntryBlock()) { 650 for (int i : llvm::seq<int>(0, block->getNumArguments())) 651 visitBlockArgument(block, i); 652 } 653 654 // Visit all of the operations within the block. 655 for (Operation &op : *block) 656 visitOperation(&op); 657 } 658 659 void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) { 660 BlockArgument arg = block->getArgument(i); 661 AbstractLatticeElement &argLattice = analysis.getLatticeElement(arg); 662 if (argLattice.isAtFixpoint()) 663 return; 664 665 ChangeResult updatedLattice = ChangeResult::NoChange; 666 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { 667 Block *pred = *it; 668 669 // We only care about this predecessor if it is going to execute. 670 if (!isEdgeExecutable(pred, block)) 671 continue; 672 673 // Try to get the operand forwarded by the predecessor. If we can't reason 674 // about the terminator of the predecessor, mark as having reached a 675 // fixpoint. 676 Optional<OperandRange> branchOperands; 677 if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator())) 678 branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex()); 679 if (!branchOperands) { 680 updatedLattice |= argLattice.markPessimisticFixpoint(); 681 break; 682 } 683 684 // If the operand hasn't been resolved, it is uninitialized and can merge 685 // with anything. 686 AbstractLatticeElement *operandLattice = 687 analysis.lookupLatticeElement((*branchOperands)[i]); 688 if (!operandLattice) 689 continue; 690 691 // Otherwise, join the operand lattice into the argument lattice. 692 updatedLattice |= argLattice.join(*operandLattice); 693 if (argLattice.isAtFixpoint()) 694 break; 695 } 696 697 // If the lattice changed, visit users of the argument. 698 if (updatedLattice == ChangeResult::Change) 699 visitUsers(arg); 700 } 701 702 ChangeResult 703 ForwardDataFlowSolver::markEntryBlockExecutable(Region *region, 704 bool markPessimisticFixpoint) { 705 if (!region->empty()) { 706 if (markPessimisticFixpoint) 707 markAllPessimisticFixpoint(region->front().getArguments()); 708 return markBlockExecutable(®ion->front()); 709 } 710 return ChangeResult::NoChange; 711 } 712 713 ChangeResult ForwardDataFlowSolver::markBlockExecutable(Block *block) { 714 bool marked = executableBlocks.insert(block).second; 715 if (marked) 716 blockWorklist.push_back(block); 717 return marked ? ChangeResult::Change : ChangeResult::NoChange; 718 } 719 720 bool ForwardDataFlowSolver::isBlockExecutable(Block *block) const { 721 return executableBlocks.count(block); 722 } 723 724 void ForwardDataFlowSolver::markEdgeExecutable(Block *from, Block *to) { 725 executableEdges.insert(std::make_pair(from, to)); 726 727 // Mark the destination as executable, and reprocess its arguments if it was 728 // already executable. 729 if (markBlockExecutable(to) == ChangeResult::NoChange) { 730 for (int i : llvm::seq<int>(0, to->getNumArguments())) 731 visitBlockArgument(to, i); 732 } 733 } 734 735 bool ForwardDataFlowSolver::isEdgeExecutable(Block *from, Block *to) const { 736 return executableEdges.count(std::make_pair(from, to)); 737 } 738 739 void ForwardDataFlowSolver::markPessimisticFixpoint(Value value) { 740 analysis.getLatticeElement(value).markPessimisticFixpoint(); 741 } 742 743 bool ForwardDataFlowSolver::isAtFixpoint(Value value) const { 744 if (auto *lattice = analysis.lookupLatticeElement(value)) 745 return lattice->isAtFixpoint(); 746 return false; 747 } 748 749 void ForwardDataFlowSolver::join(Operation *owner, AbstractLatticeElement &to, 750 const AbstractLatticeElement &from) { 751 if (to.join(from) == ChangeResult::Change) 752 opWorklist.push_back(owner); 753 } 754 755 //===----------------------------------------------------------------------===// 756 // AbstractLatticeElement 757 //===----------------------------------------------------------------------===// 758 759 AbstractLatticeElement::~AbstractLatticeElement() {} 760 761 //===----------------------------------------------------------------------===// 762 // ForwardDataFlowAnalysisBase 763 //===----------------------------------------------------------------------===// 764 765 ForwardDataFlowAnalysisBase::~ForwardDataFlowAnalysisBase() {} 766 767 AbstractLatticeElement & 768 ForwardDataFlowAnalysisBase::getLatticeElement(Value value) { 769 AbstractLatticeElement *&latticeValue = latticeValues[value]; 770 if (!latticeValue) 771 latticeValue = createLatticeElement(value); 772 return *latticeValue; 773 } 774 775 AbstractLatticeElement * 776 ForwardDataFlowAnalysisBase::lookupLatticeElement(Value value) { 777 return latticeValues.lookup(value); 778 } 779 780 void ForwardDataFlowAnalysisBase::run(Operation *topLevelOp) { 781 // Run the main dataflow solver. 782 ForwardDataFlowSolver solver(*this, topLevelOp); 783 solver.solve(); 784 785 // Any values that are still uninitialized now go to a pessimistic fixpoint, 786 // otherwise we assume an optimistic fixpoint has been reached. 787 for (auto &it : latticeValues) 788 if (it.second->isUninitialized()) 789 it.second->markPessimisticFixpoint(); 790 else 791 it.second->markOptimisticFixpoint(); 792 } 793