1c095afcbSMogball //===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
2c095afcbSMogball //
3c095afcbSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c095afcbSMogball // See https://llvm.org/LICENSE.txt for license information.
5c095afcbSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c095afcbSMogball //
7c095afcbSMogball //===----------------------------------------------------------------------===//
8c095afcbSMogball 
9c095afcbSMogball #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
109432fbfeSMogball #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
119432fbfeSMogball #include "mlir/Interfaces/CallInterfaces.h"
12c095afcbSMogball 
13c095afcbSMogball using namespace mlir;
14c095afcbSMogball using namespace mlir::dataflow;
15c095afcbSMogball 
16c095afcbSMogball //===----------------------------------------------------------------------===//
17c095afcbSMogball // AbstractSparseLattice
18c095afcbSMogball //===----------------------------------------------------------------------===//
19c095afcbSMogball 
onUpdate(DataFlowSolver * solver) const20c095afcbSMogball void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
21c095afcbSMogball   // Push all users of the value to the queue.
22c095afcbSMogball   for (Operation *user : point.get<Value>().getUsers())
23c095afcbSMogball     for (DataFlowAnalysis *analysis : useDefSubscribers)
24c095afcbSMogball       solver->enqueue({user, analysis});
25c095afcbSMogball }
269432fbfeSMogball 
279432fbfeSMogball //===----------------------------------------------------------------------===//
289432fbfeSMogball // AbstractSparseDataFlowAnalysis
299432fbfeSMogball //===----------------------------------------------------------------------===//
309432fbfeSMogball 
AbstractSparseDataFlowAnalysis(DataFlowSolver & solver)319432fbfeSMogball AbstractSparseDataFlowAnalysis::AbstractSparseDataFlowAnalysis(
329432fbfeSMogball     DataFlowSolver &solver)
339432fbfeSMogball     : DataFlowAnalysis(solver) {
349432fbfeSMogball   registerPointKind<CFGEdge>();
359432fbfeSMogball }
369432fbfeSMogball 
initialize(Operation * top)379432fbfeSMogball LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) {
389432fbfeSMogball   // Mark the entry block arguments as having reached their pessimistic
399432fbfeSMogball   // fixpoints.
409432fbfeSMogball   for (Region &region : top->getRegions()) {
419432fbfeSMogball     if (region.empty())
429432fbfeSMogball       continue;
439432fbfeSMogball     for (Value argument : region.front().getArguments())
449432fbfeSMogball       markAllPessimisticFixpoint(getLatticeElement(argument));
459432fbfeSMogball   }
469432fbfeSMogball 
479432fbfeSMogball   return initializeRecursively(top);
489432fbfeSMogball }
499432fbfeSMogball 
509432fbfeSMogball LogicalResult
initializeRecursively(Operation * op)519432fbfeSMogball AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) {
529432fbfeSMogball   // Initialize the analysis by visiting every owner of an SSA value (all
539432fbfeSMogball   // operations and blocks).
549432fbfeSMogball   visitOperation(op);
559432fbfeSMogball   for (Region &region : op->getRegions()) {
569432fbfeSMogball     for (Block &block : region) {
579432fbfeSMogball       getOrCreate<Executable>(&block)->blockContentSubscribe(this);
589432fbfeSMogball       visitBlock(&block);
599432fbfeSMogball       for (Operation &op : block)
609432fbfeSMogball         if (failed(initializeRecursively(&op)))
619432fbfeSMogball           return failure();
629432fbfeSMogball     }
639432fbfeSMogball   }
649432fbfeSMogball 
659432fbfeSMogball   return success();
669432fbfeSMogball }
679432fbfeSMogball 
visit(ProgramPoint point)689432fbfeSMogball LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
699432fbfeSMogball   if (Operation *op = point.dyn_cast<Operation *>())
709432fbfeSMogball     visitOperation(op);
719432fbfeSMogball   else if (Block *block = point.dyn_cast<Block *>())
729432fbfeSMogball     visitBlock(block);
739432fbfeSMogball   else
749432fbfeSMogball     return failure();
759432fbfeSMogball   return success();
769432fbfeSMogball }
779432fbfeSMogball 
visitOperation(Operation * op)789432fbfeSMogball void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
799432fbfeSMogball   // Exit early on operations with no results.
809432fbfeSMogball   if (op->getNumResults() == 0)
819432fbfeSMogball     return;
829432fbfeSMogball 
839432fbfeSMogball   // If the containing block is not executable, bail out.
849432fbfeSMogball   if (!getOrCreate<Executable>(op->getBlock())->isLive())
859432fbfeSMogball     return;
869432fbfeSMogball 
879432fbfeSMogball   // Get the result lattices.
889432fbfeSMogball   SmallVector<AbstractSparseLattice *> resultLattices;
899432fbfeSMogball   resultLattices.reserve(op->getNumResults());
909432fbfeSMogball   // Track whether all results have reached their fixpoint.
919432fbfeSMogball   bool allAtFixpoint = true;
929432fbfeSMogball   for (Value result : op->getResults()) {
939432fbfeSMogball     AbstractSparseLattice *resultLattice = getLatticeElement(result);
949432fbfeSMogball     allAtFixpoint &= resultLattice->isAtFixpoint();
959432fbfeSMogball     resultLattices.push_back(resultLattice);
969432fbfeSMogball   }
979432fbfeSMogball   // If all result lattices have reached a fixpoint, there is nothing to do.
989432fbfeSMogball   if (allAtFixpoint)
999432fbfeSMogball     return;
1009432fbfeSMogball 
1019432fbfeSMogball   // The results of a region branch operation are determined by control-flow.
1029432fbfeSMogball   if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
1039432fbfeSMogball     return visitRegionSuccessors({branch}, branch,
1049432fbfeSMogball                                  /*successorIndex=*/llvm::None, resultLattices);
1059432fbfeSMogball   }
1069432fbfeSMogball 
1079432fbfeSMogball   // The results of a call operation are determined by the callgraph.
1089432fbfeSMogball   if (auto call = dyn_cast<CallOpInterface>(op)) {
1099432fbfeSMogball     const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
1109432fbfeSMogball     // If not all return sites are known, then conservatively assume we can't
1119432fbfeSMogball     // reason about the data-flow.
1129432fbfeSMogball     if (!predecessors->allPredecessorsKnown())
1139432fbfeSMogball       return markAllPessimisticFixpoint(resultLattices);
1149432fbfeSMogball     for (Operation *predecessor : predecessors->getKnownPredecessors())
1159432fbfeSMogball       for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
1169432fbfeSMogball         join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
1179432fbfeSMogball     return;
1189432fbfeSMogball   }
1199432fbfeSMogball 
1209432fbfeSMogball   // Grab the lattice elements of the operands.
1219432fbfeSMogball   SmallVector<const AbstractSparseLattice *> operandLattices;
1229432fbfeSMogball   operandLattices.reserve(op->getNumOperands());
1239432fbfeSMogball   for (Value operand : op->getOperands()) {
1249432fbfeSMogball     AbstractSparseLattice *operandLattice = getLatticeElement(operand);
1259432fbfeSMogball     operandLattice->useDefSubscribe(this);
1269432fbfeSMogball     // If any of the operand states are not initialized, bail out.
1279432fbfeSMogball     if (operandLattice->isUninitialized())
1289432fbfeSMogball       return;
1299432fbfeSMogball     operandLattices.push_back(operandLattice);
1309432fbfeSMogball   }
1319432fbfeSMogball 
1329432fbfeSMogball   // Invoke the operation transfer function.
1339432fbfeSMogball   visitOperationImpl(op, operandLattices, resultLattices);
1349432fbfeSMogball }
1359432fbfeSMogball 
visitBlock(Block * block)1369432fbfeSMogball void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
1379432fbfeSMogball   // Exit early on blocks with no arguments.
1389432fbfeSMogball   if (block->getNumArguments() == 0)
1399432fbfeSMogball     return;
1409432fbfeSMogball 
1419432fbfeSMogball   // If the block is not executable, bail out.
1429432fbfeSMogball   if (!getOrCreate<Executable>(block)->isLive())
1439432fbfeSMogball     return;
1449432fbfeSMogball 
1459432fbfeSMogball   // Get the argument lattices.
1469432fbfeSMogball   SmallVector<AbstractSparseLattice *> argLattices;
1479432fbfeSMogball   argLattices.reserve(block->getNumArguments());
1489432fbfeSMogball   bool allAtFixpoint = true;
1499432fbfeSMogball   for (BlockArgument argument : block->getArguments()) {
1509432fbfeSMogball     AbstractSparseLattice *argLattice = getLatticeElement(argument);
1519432fbfeSMogball     allAtFixpoint &= argLattice->isAtFixpoint();
1529432fbfeSMogball     argLattices.push_back(argLattice);
1539432fbfeSMogball   }
1549432fbfeSMogball   // If all argument lattices have reached their fixpoints, then there is
1559432fbfeSMogball   // nothing to do.
1569432fbfeSMogball   if (allAtFixpoint)
1579432fbfeSMogball     return;
1589432fbfeSMogball 
1599432fbfeSMogball   // The argument lattices of entry blocks are set by region control-flow or the
1609432fbfeSMogball   // callgraph.
1619432fbfeSMogball   if (block->isEntryBlock()) {
1629432fbfeSMogball     // Check if this block is the entry block of a callable region.
1639432fbfeSMogball     auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
1649432fbfeSMogball     if (callable && callable.getCallableRegion() == block->getParent()) {
1659432fbfeSMogball       const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
1669432fbfeSMogball       // If not all callsites are known, conservatively mark all lattices as
1679432fbfeSMogball       // having reached their pessimistic fixpoints.
1689432fbfeSMogball       if (!callsites->allPredecessorsKnown())
1699432fbfeSMogball         return markAllPessimisticFixpoint(argLattices);
1709432fbfeSMogball       for (Operation *callsite : callsites->getKnownPredecessors()) {
1719432fbfeSMogball         auto call = cast<CallOpInterface>(callsite);
1729432fbfeSMogball         for (auto it : llvm::zip(call.getArgOperands(), argLattices))
1739432fbfeSMogball           join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
1749432fbfeSMogball       }
1759432fbfeSMogball       return;
1769432fbfeSMogball     }
1779432fbfeSMogball 
1789432fbfeSMogball     // Check if the lattices can be determined from region control flow.
1799432fbfeSMogball     if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
1809432fbfeSMogball       return visitRegionSuccessors(
1819432fbfeSMogball           block, branch, block->getParent()->getRegionNumber(), argLattices);
1829432fbfeSMogball     }
1839432fbfeSMogball 
1849432fbfeSMogball     // Otherwise, we can't reason about the data-flow.
185*ab701975SMogball     return visitNonControlFlowArgumentsImpl(block->getParentOp(),
186*ab701975SMogball                                             RegionSuccessor(block->getParent()),
187*ab701975SMogball                                             argLattices, /*firstIndex=*/0);
1889432fbfeSMogball   }
1899432fbfeSMogball 
1909432fbfeSMogball   // Iterate over the predecessors of the non-entry block.
1919432fbfeSMogball   for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
1929432fbfeSMogball        it != e; ++it) {
1939432fbfeSMogball     Block *predecessor = *it;
1949432fbfeSMogball 
1959432fbfeSMogball     // If the edge from the predecessor block to the current block is not live,
1969432fbfeSMogball     // bail out.
1979432fbfeSMogball     auto *edgeExecutable =
1989432fbfeSMogball         getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
1999432fbfeSMogball     edgeExecutable->blockContentSubscribe(this);
2009432fbfeSMogball     if (!edgeExecutable->isLive())
2019432fbfeSMogball       continue;
2029432fbfeSMogball 
2039432fbfeSMogball     // Check if we can reason about the data-flow from the predecessor.
2049432fbfeSMogball     if (auto branch =
2059432fbfeSMogball             dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
2069432fbfeSMogball       SuccessorOperands operands =
2079432fbfeSMogball           branch.getSuccessorOperands(it.getSuccessorIndex());
2089432fbfeSMogball       for (auto &it : llvm::enumerate(argLattices)) {
2099432fbfeSMogball         if (Value operand = operands[it.index()]) {
2109432fbfeSMogball           join(it.value(), *getLatticeElementFor(block, operand));
2119432fbfeSMogball         } else {
2129432fbfeSMogball           // Conservatively mark internally produced arguments as having reached
2139432fbfeSMogball           // their pessimistic fixpoint.
2149432fbfeSMogball           markAllPessimisticFixpoint(it.value());
2159432fbfeSMogball         }
2169432fbfeSMogball       }
2179432fbfeSMogball     } else {
2189432fbfeSMogball       return markAllPessimisticFixpoint(argLattices);
2199432fbfeSMogball     }
2209432fbfeSMogball   }
2219432fbfeSMogball }
2229432fbfeSMogball 
visitRegionSuccessors(ProgramPoint point,RegionBranchOpInterface branch,Optional<unsigned> successorIndex,ArrayRef<AbstractSparseLattice * > lattices)2239432fbfeSMogball void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
2249432fbfeSMogball     ProgramPoint point, RegionBranchOpInterface branch,
2259432fbfeSMogball     Optional<unsigned> successorIndex,
2269432fbfeSMogball     ArrayRef<AbstractSparseLattice *> lattices) {
2279432fbfeSMogball   const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
2289432fbfeSMogball   assert(predecessors->allPredecessorsKnown() &&
2299432fbfeSMogball          "unexpected unresolved region successors");
2309432fbfeSMogball 
2319432fbfeSMogball   for (Operation *op : predecessors->getKnownPredecessors()) {
2329432fbfeSMogball     // Get the incoming successor operands.
2339432fbfeSMogball     Optional<OperandRange> operands;
2349432fbfeSMogball 
2359432fbfeSMogball     // Check if the predecessor is the parent op.
2369432fbfeSMogball     if (op == branch) {
2379432fbfeSMogball       operands = branch.getSuccessorEntryOperands(successorIndex);
2389432fbfeSMogball       // Otherwise, try to deduce the operands from a region return-like op.
2399432fbfeSMogball     } else {
2409432fbfeSMogball       if (isRegionReturnLike(op))
2419432fbfeSMogball         operands = getRegionBranchSuccessorOperands(op, successorIndex);
2429432fbfeSMogball     }
2439432fbfeSMogball 
2449432fbfeSMogball     if (!operands) {
2459432fbfeSMogball       // We can't reason about the data-flow.
2469432fbfeSMogball       return markAllPessimisticFixpoint(lattices);
2479432fbfeSMogball     }
2489432fbfeSMogball 
2499432fbfeSMogball     ValueRange inputs = predecessors->getSuccessorInputs(op);
2509432fbfeSMogball     assert(inputs.size() == operands->size() &&
2519432fbfeSMogball            "expected the same number of successor inputs as operands");
2529432fbfeSMogball 
2539432fbfeSMogball     unsigned firstIndex = 0;
2549432fbfeSMogball     if (inputs.size() != lattices.size()) {
255*ab701975SMogball       if (auto *op = point.dyn_cast<Operation *>()) {
256*ab701975SMogball         if (!inputs.empty())
257*ab701975SMogball           firstIndex = inputs.front().cast<OpResult>().getResultNumber();
258*ab701975SMogball         visitNonControlFlowArgumentsImpl(
259*ab701975SMogball             branch,
260*ab701975SMogball             RegionSuccessor(
261*ab701975SMogball                 branch->getResults().slice(firstIndex, inputs.size())),
262*ab701975SMogball             lattices, firstIndex);
263*ab701975SMogball       } else {
264*ab701975SMogball         if (!inputs.empty())
2659432fbfeSMogball           firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
266*ab701975SMogball         Region *region = point.get<Block *>()->getParent();
267*ab701975SMogball         visitNonControlFlowArgumentsImpl(
268*ab701975SMogball             branch,
269*ab701975SMogball             RegionSuccessor(region, region->getArguments().slice(
270*ab701975SMogball                                         firstIndex, inputs.size())),
271*ab701975SMogball             lattices, firstIndex);
272*ab701975SMogball       }
2739432fbfeSMogball     }
2749432fbfeSMogball 
2759432fbfeSMogball     for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
2769432fbfeSMogball       join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
2779432fbfeSMogball   }
2789432fbfeSMogball }
2799432fbfeSMogball 
2809432fbfeSMogball const AbstractSparseLattice *
getLatticeElementFor(ProgramPoint point,Value value)2819432fbfeSMogball AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
2829432fbfeSMogball                                                      Value value) {
2839432fbfeSMogball   AbstractSparseLattice *state = getLatticeElement(value);
2849432fbfeSMogball   addDependency(state, point);
2859432fbfeSMogball   return state;
2869432fbfeSMogball }
2879432fbfeSMogball 
markAllPessimisticFixpoint(ArrayRef<AbstractSparseLattice * > lattices)2889432fbfeSMogball void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
2899432fbfeSMogball     ArrayRef<AbstractSparseLattice *> lattices) {
2909432fbfeSMogball   for (AbstractSparseLattice *lattice : lattices)
2919432fbfeSMogball     propagateIfChanged(lattice, lattice->markPessimisticFixpoint());
2929432fbfeSMogball }
2939432fbfeSMogball 
join(AbstractSparseLattice * lhs,const AbstractSparseLattice & rhs)2949432fbfeSMogball void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs,
2959432fbfeSMogball                                           const AbstractSparseLattice &rhs) {
2969432fbfeSMogball   propagateIfChanged(lhs, lhs->join(rhs));
2979432fbfeSMogball }
298