1 //===- ReductionNode.h - Reduction Node Implementation ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the reduction nodes, which are used to keep track of the
// metadata for a specific generated variant within a reduction pass and are the
// building blocks of the reduction tree structure. A reduction tree is used to
// keep track of the different generated variants throughout a reduction pass in
// the MLIR Reduce tool.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_REDUCER_REDUCTIONNODE_H
18 #define MLIR_REDUCER_REDUCTIONNODE_H
19 
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ToolOutputFile.h"

#include <cassert>
#include <queue>
#include <utility>
#include <vector>
29 
30 namespace mlir {
31 
32 class ModuleOp;
33 class Region;
34 
/// Traversal strategies available when walking the reduction tree.
///
/// Each enumerator is spelled out with its value (identical to the implicit
/// numbering) so that any reordering shows up as an explicit change.
enum TraversalMode {
  SinglePath = 0,
  Backtrack = 1,
  MultiPath = 2,
};
38 
39 /// ReductionTreePass will build a reduction tree during module reduction and
40 /// the ReductionNode represents the vertex of the tree. A ReductionNode records
41 /// the information such as the reduced module, how this node is reduced from
42 /// the parent node, etc. This information will be used to construct a reduction
43 /// path to reduce the certain module.
44 class ReductionNode {
45 public:
46   template <TraversalMode mode>
47   class iterator;
48 
49   using Range = std::pair<int, int>;
50 
51   ReductionNode(ReductionNode *parent, const std::vector<Range> &range,
52                 llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator);
53 
getParent()54   ReductionNode *getParent() const { return parent; }
55 
56   /// If the ReductionNode hasn't been tested the interestingness, it'll be the
57   /// same module as the one in the parent node. Otherwise, the returned module
58   /// will have been applied certain reduction strategies. Note that it's not
59   /// necessary to be an interesting case or a reduced module (has smaller size
60   /// than parent's).
getModule()61   ModuleOp getModule() const { return module.get(); }
62 
63   /// Return the region we're reducing.
getRegion()64   Region &getRegion() const { return *region; }
65 
66   /// Return the size of the module.
getSize()67   size_t getSize() const { return size; }
68 
69   /// Returns true if the module exhibits the interesting behavior.
isInteresting()70   Tester::Interestingness isInteresting() const { return interesting; }
71 
72   /// Return the range information that how this node is reduced from the parent
73   /// node.
getStartRanges()74   ArrayRef<Range> getStartRanges() const { return startRanges; }
75 
76   /// Return the range set we are using to generate variants.
getRanges()77   ArrayRef<Range> getRanges() const { return ranges; }
78 
79   /// Return the generated variants(the child nodes).
getVariants()80   ArrayRef<ReductionNode *> getVariants() const { return variants; }
81 
82   /// Split the ranges and generate new variants.
83   ArrayRef<ReductionNode *> generateNewVariants();
84 
85   /// Update the interestingness result from tester.
86   void update(std::pair<Tester::Interestingness, size_t> result);
87 
88   /// Each Reduction Node contains a copy of module for applying rewrite
89   /// patterns. In addition, we only apply rewrite patterns in a certain region.
90   /// In init(), we will duplicate the module from parent node and locate the
91   /// corresponding region.
92   LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
93 
94 private:
95   /// A custom BFS iterator. The difference between
96   /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
97   /// We may explore more neighbors at certain node if we didn't find interested
98   /// event. As a result, we defer pushing adjacent nodes until poping the last
99   /// visited node. The graph exploration strategy will be put in
100   /// getNeighbors().
101   ///
102   /// Subclass BaseIterator and implement traversal strategy in getNeighbors().
103   template <typename T>
104   class BaseIterator {
105   public:
BaseIterator(ReductionNode * node)106     BaseIterator(ReductionNode *node) { visitQueue.push(node); }
107     BaseIterator(const BaseIterator &) = default;
108     BaseIterator() = default;
109 
end()110     static BaseIterator end() { return BaseIterator(); }
111 
112     bool operator==(const BaseIterator &i) {
113       return visitQueue == i.visitQueue;
114     }
115     bool operator!=(const BaseIterator &i) { return !(*this == i); }
116 
117     BaseIterator &operator++() {
118       ReductionNode *top = visitQueue.front();
119       visitQueue.pop();
120       for (ReductionNode *node : getNeighbors(top))
121         visitQueue.push(node);
122       return *this;
123     }
124 
125     BaseIterator operator++(int) {
126       BaseIterator tmp = *this;
127       ++*this;
128       return tmp;
129     }
130 
131     ReductionNode &operator*() const { return *(visitQueue.front()); }
132     ReductionNode *operator->() const { return visitQueue.front(); }
133 
134   protected:
getNeighbors(ReductionNode * node)135     ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node) {
136       return static_cast<T *>(this)->getNeighbors(node);
137     }
138 
139   private:
140     std::queue<ReductionNode *> visitQueue;
141   };
142 
143   /// This is a copy of module from parent node. All the reducer patterns will
144   /// be applied to this instance.
145   OwningOpRef<ModuleOp> module;
146 
147   /// The region of certain operation we're reducing in the module
148   Region *region = nullptr;
149 
150   /// The node we are reduced from. It means we will be in variants of parent
151   /// node.
152   ReductionNode *parent = nullptr;
153 
154   /// The size of module after applying the reducer patterns with range
155   /// constraints. This is only valid while the interestingness has been tested.
156   size_t size = 0;
157 
158   /// This is true if the module has been evaluated and it exhibits the
159   /// interesting behavior.
160   Tester::Interestingness interesting = Tester::Interestingness::Untested;
161 
162   /// `ranges` represents the selected subset of operations in the region. We
163   /// implicitly number each operation in the region and ReductionTreePass will
164   /// apply reducer patterns on the operation falls into the `ranges`. We will
165   /// generate new ReductionNode with subset of `ranges` to see if we can do
166   /// further reduction. we may split the element in the `ranges` so that we can
167   /// have more subset variants from `ranges`.
168   /// Note that after applying the reducer patterns the number of operation in
169   /// the region may have changed, we need to update the `ranges` after that.
170   std::vector<Range> ranges;
171 
172   /// `startRanges` records the ranges of operations selected from the parent
173   /// node to produce this ReductionNode. It can be used to construct the
174   /// reduction path from the root. I.e., if we apply the same reducer patterns
175   /// and `startRanges` selection on the parent region, we will get the same
176   /// module as this node.
177   const std::vector<Range> startRanges;
178 
179   /// This points to the child variants that were created using this node as a
180   /// starting point.
181   std::vector<ReductionNode *> variants;
182 
183   llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator;
184 };
185 
186 // Specialized iterator for SinglePath traversal
187 template <>
188 class ReductionNode::iterator<SinglePath>
189     : public BaseIterator<iterator<SinglePath>> {
190   friend BaseIterator<iterator<SinglePath>>;
191   using BaseIterator::BaseIterator;
192   ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node);
193 };
194 
195 } // namespace mlir
196 
197 #endif // MLIR_REDUCER_REDUCTIONNODE_H
198