# Writing DataFlow Analyses in MLIR

Writing dataflow analyses in MLIR, or indeed in any compiler, can often seem
daunting and complex. A dataflow analysis generally involves propagating
information about the IR across various types of control flow
constructs, of which MLIR has many (Block-based branches, Region-based branches,
CallGraph, etc.), and it isn't always clear how best to go about performing the
propagation. To help with writing these types of analyses in MLIR, this document
details several utilities that simplify the process and make it a bit more
approachable.

## Forward Dataflow Analysis

One type of dataflow analysis is a forward propagation analysis. This type of
analysis, as the name may suggest, propagates information forward (e.g. from
definitions to uses). To provide a bit of concrete context, let's go over
writing a simple forward dataflow analysis in MLIR. Let's say for this analysis
that we want to propagate information about a special "metadata" dictionary
attribute. The contents of this attribute are simply a set of metadata that
describe a specific value, e.g. `metadata = { likes_pizza = true }`. We will
collect the `metadata` for operations in the IR and propagate it through the IR.

### Lattices

Before going into how one might set up the analysis itself, it is important to
first introduce the concept of a `Lattice` and how we will use it for the
analysis. A lattice represents all of the possible values or results of the
analysis for a given value. A lattice element holds the set of information
computed by the analysis for a given value, and is what gets propagated across
the IR. For our analysis, this would correspond to the `metadata` dictionary
attribute.

Regardless of the value held within, every type of lattice contains two special
element states:

*   `uninitialized`

    -   The element has not been initialized.

*   `top`/`overdefined`/`unknown`

    -   The element encompasses every possible value.
    -   This is a very conservative state, and essentially means "I can't make
        any assumptions about the value, it could be anything".

These two states are important when merging information as part of the
analysis, an operation we will refer to as `join`ing in the rest of this
document. Lattice elements are `join`ed whenever there are two different source
points, such as an argument to a block with multiple predecessors. One
important note about the `join` operation is that it is required to be
monotonic (see the `join` method in the example below for more information).
This ensures that `join`ing elements is consistent. The two special states
mentioned above have unique properties during a `join`:

*   `uninitialized`

    -   If one of the elements is `uninitialized`, the other element is used.
    -   `uninitialized` in the context of a `join` essentially means "take the
        other thing".

*   `top`/`overdefined`/`unknown`

    -   If one of the elements being joined is `overdefined`, the result is
        `overdefined`.
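
To make the two special states concrete, here is a minimal, self-contained
sketch of a `join` over a three-state element. It does not use MLIR:
`ConstElement`, `known`, `overdefined`, and `checkSpecialStates` are
hypothetical names introduced only for illustration, and the tracked payload is
assumed to be a plain integer constant rather than a `metadata` dictionary.

```cpp
#include <cassert>

// Hypothetical three-state element tracking an integer constant. This is an
// illustration only, not MLIR's `LatticeElement`.
struct ConstElement {
  enum class Kind { Uninitialized, Known, Overdefined };
  Kind kind = Kind::Uninitialized;
  int value = 0; // Only meaningful when `kind` is `Known`.

  static ConstElement known(int v) {
    ConstElement element;
    element.kind = Kind::Known;
    element.value = v;
    return element;
  }
  static ConstElement overdefined() {
    ConstElement element;
    element.kind = Kind::Overdefined;
    return element;
  }

  bool operator==(const ConstElement &rhs) const {
    if (kind != rhs.kind)
      return false;
    return kind != Kind::Known || value == rhs.value;
  }
};

ConstElement join(const ConstElement &lhs, const ConstElement &rhs) {
  using Kind = ConstElement::Kind;
  // `uninitialized` means "take the other thing".
  if (lhs.kind == Kind::Uninitialized)
    return rhs;
  if (rhs.kind == Kind::Uninitialized)
    return lhs;
  // `overdefined` absorbs whatever it is joined with.
  if (lhs.kind == Kind::Overdefined || rhs.kind == Kind::Overdefined)
    return ConstElement::overdefined();
  // Two known values: keep them if they agree, otherwise give up.
  if (lhs.value == rhs.value)
    return lhs;
  return ConstElement::overdefined();
}

// Exercises the rules for the two special states.
void checkSpecialStates() {
  ConstElement uninit;
  ConstElement four = ConstElement::known(4);
  assert(join(uninit, four) == four);
  assert(join(four, ConstElement::overdefined()) ==
         ConstElement::overdefined());
  assert(join(four, ConstElement::known(5)) == ConstElement::overdefined());
}
```

In this sketch, conflicting known values collapse to `overdefined`. That is one
possible policy; the `metadata` analysis in this document instead keeps the
facts that both sides agree on.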

For our analysis in MLIR, we will need to define a class representing the value
held by an element of the lattice used by our dataflow analysis:

```c++
/// The value of our lattice represents the inner structure of a DictionaryAttr,
/// for the `metadata`.
struct MetadataLatticeValue {
  MetadataLatticeValue() = default;
  /// Compute a lattice value from the provided dictionary.
  MetadataLatticeValue(DictionaryAttr attr)
      : metadata(attr.begin(), attr.end()) {}

  /// Return a pessimistic value state, i.e. the `top`/`overdefined`/`unknown`
  /// state, for our value type. The resultant state should not assume any
  /// information about the state of the IR.
  static MetadataLatticeValue getPessimisticValueState(MLIRContext *context) {
    // The `top`/`overdefined`/`unknown` state is when we know nothing about any
    // metadata, i.e. an empty dictionary.
    return MetadataLatticeValue();
  }
  /// Return a pessimistic value state for our value type using only information
  /// about the state of the provided IR. This is similar to the above method,
  /// but may produce a slightly more refined result. This is okay, as the
  /// information is already encoded as fact in the IR.
  static MetadataLatticeValue getPessimisticValueState(Value value) {
    // Check to see if the parent operation has metadata.
    if (Operation *parentOp = value.getDefiningOp()) {
      if (auto metadata = parentOp->getAttrOfType<DictionaryAttr>("metadata"))
        return MetadataLatticeValue(metadata);

      // If no metadata is present, fall back to the
      // `top`/`overdefined`/`unknown` state.
    }
    return MetadataLatticeValue();
  }

  /// This method conservatively joins the information held by `lhs` and `rhs`
  /// into a new value. This method is required to be monotonic. `monotonicity`
  /// is implied by the satisfaction of the following axioms:
  ///   * idempotence:   join(x,x) == x
  ///   * commutativity: join(x,y) == join(y,x)
  ///   * associativity: join(x,join(y,z)) == join(join(x,y),z)
  ///
  /// When the above axioms are satisfied, we achieve `monotonicity`:
  ///   * monotonicity: join(x, join(x,y)) == join(x,y)
  static MetadataLatticeValue join(const MetadataLatticeValue &lhs,
                                   const MetadataLatticeValue &rhs) {
    // To join `lhs` and `rhs` we will define a simple policy, which is that we
    // only keep information that is the same. This means that we only keep
    // facts that are true in both.
    MetadataLatticeValue result;
    for (const auto &lhsIt : lhs.metadata) {
      // As noted above, we only merge if the values are the same.
      auto it = rhs.metadata.find(lhsIt.first);
      if (it == rhs.metadata.end() || it->second != lhsIt.second)
        continue;
      result.metadata.insert(lhsIt);
    }
    return result;
  }

  /// A simple comparator that checks to see if this value is equal to the one
  /// provided.
  bool operator==(const MetadataLatticeValue &rhs) const {
    if (metadata.size() != rhs.metadata.size())
      return false;
    // Check that 'rhs' contains the same keys, each mapped to the same value.
    return llvm::all_of(metadata, [&](auto &it) {
      auto rhsIt = rhs.metadata.find(it.first);
      return rhsIt != rhs.metadata.end() && rhsIt->second == it.second;
    });
  }

  /// Our value represents the combined metadata, which is originally a
  /// DictionaryAttr, so we use a map.
  DenseMap<StringAttr, Attribute> metadata;
};
```
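
The axioms listed in the `join` comment can be checked on small inputs. The
following is a rough, MLIR-free sketch of that check: `Metadata`,
`joinMetadata`, and `checkJoinAxioms` are hypothetical names, and a
`std::map<std::string, bool>` is assumed as a stand-in for the
`DenseMap<StringAttr, Attribute>` used above.

```cpp
#include <cassert>
#include <map>
#include <string>

// Stand-in for `MetadataLatticeValue::metadata`, built on standard containers
// so that the sketch compiles without MLIR.
using Metadata = std::map<std::string, bool>;

// Same policy as `MetadataLatticeValue::join`: keep only the entries that are
// present in both maps with the same value.
Metadata joinMetadata(const Metadata &lhs, const Metadata &rhs) {
  Metadata result;
  for (const auto &entry : lhs) {
    auto it = rhs.find(entry.first);
    if (it != rhs.end() && it->second == entry.second)
      result.insert(entry);
  }
  return result;
}

// Spot-checks the axioms on three sample dictionaries.
void checkJoinAxioms() {
  Metadata x = {{"likes_pizza", true}, {"is_tall", false}};
  Metadata y = {{"likes_pizza", true}, {"is_tall", true}};
  Metadata z = {{"likes_pizza", true}};

  // Idempotence.
  assert(joinMetadata(x, x) == x);
  // Commutativity.
  assert(joinMetadata(x, y) == joinMetadata(y, x));
  // Associativity.
  assert(joinMetadata(x, joinMetadata(y, z)) ==
         joinMetadata(joinMetadata(x, y), z));
  // The empty dictionary plays the role of `top`: joining with it yields it.
  assert(joinMetadata(x, Metadata()).empty());
}
```

A check like this only samples a few inputs, so it can catch a broken `join`
but cannot prove that the axioms hold in general.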

One interesting thing to note above is that we don't have an explicit method for
the `uninitialized` state. This state is handled by the `LatticeElement` class,
which manages a lattice value for a given IR entity. A quick overview of this
class, and the API that will be interesting to us while writing our analysis, is
shown below:

```c++
/// This class represents a lattice element holding a specific value of type
/// `ValueT`.
template <typename ValueT>
class LatticeElement ... {
public:
  /// Return the value held by this element. This requires that a value is
  /// known, i.e. not `uninitialized`.
  ValueT &getValue();
  const ValueT &getValue() const;

  /// Join the information contained in the 'rhs' element into this
  /// element. Returns whether the state of the current element changed.
  ChangeResult join(const LatticeElement<ValueT> &rhs);

  /// Join the information contained in the 'rhs' value into this
  /// lattice. Returns whether the state of the current lattice changed.
  ChangeResult join(const ValueT &rhs);

  /// Mark the lattice element as having reached a pessimistic fixpoint. This
  /// means that the lattice may potentially have conflicting value states, and
  /// only the conservatively known value state should be relied on.
  ChangeResult markPessimisticFixPoint();
};
```
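
To illustrate why `join` reports a `ChangeResult`, here is a small standalone
sketch of an element that tracks whether it is still `uninitialized` and
reports when a `join` actually changed its state. Everything in it is a
hypothetical stand-in written for this example (`FactElement`, `joinAll`, the
local `ChangeResult` enum, and the use of `std::set<int>` as the value type);
it is not the MLIR implementation.

```cpp
#include <cassert>
#include <set>
#include <utility>
#include <vector>

// Hypothetical stand-ins for illustration only; these are not the MLIR classes.
enum class ChangeResult { NoChange, Change };

inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) {
  if (rhs == ChangeResult::Change)
    lhs = ChangeResult::Change;
  return lhs;
}

// A toy element whose value is a set of facts. Joining keeps only the facts
// that both sides agree on.
class FactElement {
public:
  bool isUninitialized() const { return !initialized; }

  // Requires that a value is known, i.e. not `uninitialized`.
  const std::set<int> &getValue() const {
    assert(initialized && "element has no value yet");
    return facts;
  }

  ChangeResult join(const std::set<int> &rhs) {
    // `uninitialized` means "take the other thing".
    if (!initialized) {
      facts = rhs;
      initialized = true;
      return ChangeResult::Change;
    }
    std::set<int> merged;
    for (int fact : facts)
      if (rhs.count(fact))
        merged.insert(fact);
    if (merged == facts)
      return ChangeResult::NoChange;
    facts = std::move(merged);
    return ChangeResult::Change;
  }

private:
  bool initialized = false;
  std::set<int> facts;
};

// Joins every incoming value into `element`, accumulating whether anything
// changed along the way.
ChangeResult joinAll(FactElement &element,
                     const std::vector<std::set<int>> &incoming) {
  ChangeResult result = ChangeResult::NoChange;
  for (const std::set<int> &facts : incoming)
    result |= element.join(facts);
  return result;
}
```

A driver can use the accumulated result to decide whether dependent values
need to be revisited. Once repeating `joinAll` with the same inputs reports
`NoChange`, the element has reached a fixpoint for those inputs.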

With our lattice defined, we can now define the driver that will compute and
propagate our lattice across the IR.

### ForwardDataFlowAnalysis Driver

The `ForwardDataFlowAnalysis` class represents the driver of the dataflow
analysis, and performs all of the related analysis computation. When defining
our analysis, we will inherit from this class and implement some of its hooks.
Before that, let's look at a quick overview of this class and some of the
important API for our analysis:

```c++
/// This class represents the main driver of the forward dataflow analysis. It
/// takes as a template parameter the value type of lattice being computed.
template <typename ValueT>
class ForwardDataFlowAnalysis : ... {
public:
  ForwardDataFlowAnalysis(MLIRContext *context);

  /// Compute the analysis on operations rooted under the given top-level
  /// operation. Note that the top-level operation is not visited.
  void run(Operation *topLevelOp);

  /// Return the lattice element attached to the given value. If a lattice has
  /// not been added for the given value, a new 'uninitialized' value is
  /// inserted and returned.
  LatticeElement<ValueT> &getLatticeElement(Value value);

  /// Return the lattice element attached to the given value, or nullptr if no
  /// lattice element for the value has yet been created.
  LatticeElement<ValueT> *lookupLatticeElement(Value value);

  /// Mark all of the lattice elements for the given range of Values as having
  /// reached a pessimistic fixpoint.
  ChangeResult markAllPessimisticFixPoint(ValueRange values);

protected:
  /// Visit the given operation, and join any necessary analysis state
  /// into the lattice elements for the results and block arguments owned by
  /// this operation using the provided set of operand lattice elements
  /// (all pointer values are guaranteed to be non-null). Returns if any result
  /// or block argument value lattice elements changed during the visit. The
  /// lattice element for a result or block argument value can be obtained, and
  /// join'ed into, by using `getLatticeElement`.
  virtual ChangeResult visitOperation(
      Operation *op, ArrayRef<LatticeElement<ValueT> *> operands) = 0;
};
```

NOTE: Some API has been elided for our example. The `ForwardDataFlowAnalysis`
contains various other hooks that allow for injecting custom behavior when
applicable.

The main API that we are responsible for defining is the `visitOperation`
method. This method is responsible for computing new lattice elements for the
results and block arguments owned by the given operation. This is where we will
inject the lattice element computation logic, also known as the transfer
function for the operation, that is specific to our analysis. A simple
implementation for our example is shown below:

```c++
class MetadataAnalysis : public ForwardDataFlowAnalysis<MetadataLatticeValue> {
public:
  using ForwardDataFlowAnalysis<MetadataLatticeValue>::ForwardDataFlowAnalysis;

  ChangeResult visitOperation(
      Operation *op,
      ArrayRef<LatticeElement<MetadataLatticeValue> *> operands) override {
    DictionaryAttr metadata = op->getAttrOfType<DictionaryAttr>("metadata");

    // If we have no metadata for this operation, we will conservatively mark
    // all of the results as having reached a pessimistic fixpoint.
    if (!metadata)
      return markAllPessimisticFixPoint(op->getResults());

    // Otherwise, we will compute a lattice value for the metadata and join it
    // into the current lattice element for all of our results.
    MetadataLatticeValue latticeValue(metadata);
    ChangeResult result = ChangeResult::NoChange;
    for (Value value : op->getResults()) {
      // We grab the lattice element for `value` via `getLatticeElement` and
      // then join it with the lattice value for this operation's metadata. Note
      // that during the analysis phase, it is fine to freely create a new
      // lattice element for a value. This is why we don't use the
      // `lookupLatticeElement` method here.
      result |= getLatticeElement(value).join(latticeValue);
    }
    return result;
  }
};
```

With that, we have all of the necessary components to compute our analysis.
After the analysis has been computed, we can grab any computed information for
values by using `lookupLatticeElement`. We use this function over
`getLatticeElement` as the analysis is not guaranteed to visit all values, e.g.
if the value is in an unreachable block, and we don't want to create a new
uninitialized lattice element in this case. See below for a quick example:

```c++
void MyPass::runOnOperation() {
  MetadataAnalysis analysis(&getContext());
  analysis.run(getOperation());
  ...
}

void MyPass::useAnalysisOn(MetadataAnalysis &analysis, Value value) {
  LatticeElement<MetadataLatticeValue> *latticeElement =
      analysis.lookupLatticeElement(value);

  // If we don't have an element, the `value` wasn't visited during our analysis
  // meaning that it could be dead. We need to treat this conservatively.
  if (!latticeElement)
    return;

  // Our lattice element has a value, use it:
  MetadataLatticeValue &metadataValue = latticeElement->getValue();
  ...
}
```
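
The difference between the two accessors can be pictured with an ordinary map.
The sketch below is an analogy only, under the assumption that elements are
stored in a table keyed by value: `ElementTable`, `getElement`,
`lookupElement`, and `checkLookupDoesNotInsert` are hypothetical names, and an
`int` stands in for the lattice element.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

// Hypothetical table mapping a value name to its element. This mirrors the
// get/lookup split described above, not the actual MLIR storage.
class ElementTable {
public:
  // Like `getLatticeElement`: creates a default entry if none exists yet.
  int &getElement(const std::string &name) { return elements[name]; }

  // Like `lookupLatticeElement`: never creates an entry, and returns nullptr
  // when the value was not visited.
  int *lookupElement(const std::string &name) {
    auto it = elements.find(name);
    return it == elements.end() ? nullptr : &it->second;
  }

  std::size_t size() const { return elements.size(); }

private:
  std::unordered_map<std::string, int> elements;
};

// Shows that a failed lookup leaves the table untouched.
void checkLookupDoesNotInsert() {
  ElementTable table;
  assert(table.lookupElement("%dead") == nullptr);
  assert(table.size() == 0);

  table.getElement("%live") = 7;
  assert(table.size() == 1);
  assert(*table.lookupElement("%live") == 7);
}
```

Querying with the lookup-style accessor after the analysis keeps unvisited
values distinguishable from visited ones, which the get-style accessor would
blur by inserting a fresh entry.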