1 //===- ReductionNode.h - Reduction Node Implementation ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the reduction nodes which are used to track of the metadata 10 // for a specific generated variant within a reduction pass and are the building 11 // blocks of the reduction tree structure. A reduction tree is used to keep 12 // track of the different generated variants throughout a reduction pass in the 13 // MLIR Reduce tool. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #ifndef MLIR_REDUCER_REDUCTIONNODE_H 18 #define MLIR_REDUCER_REDUCTIONNODE_H 19 20 #include <queue> 21 #include <vector> 22 23 #include "mlir/IR/OwningOpRef.h" 24 #include "mlir/Reducer/Tester.h" 25 #include "mlir/Support/LogicalResult.h" 26 #include "llvm/ADT/ArrayRef.h" 27 #include "llvm/Support/Allocator.h" 28 #include "llvm/Support/ToolOutputFile.h" 29 30 namespace mlir { 31 32 class ModuleOp; 33 class Region; 34 35 /// Defines the traversal method options to be used in the reduction tree 36 /// traversal. 37 enum TraversalMode { SinglePath, Backtrack, MultiPath }; 38 39 /// ReductionTreePass will build a reduction tree during module reduction and 40 /// the ReductionNode represents the vertex of the tree. A ReductionNode records 41 /// the information such as the reduced module, how this node is reduced from 42 /// the parent node, etc. This information will be used to construct a reduction 43 /// path to reduce the certain module. 44 class ReductionNode { 45 public: 46 template <TraversalMode mode> 47 class iterator; 48 49 using Range = std::pair<int, int>; 50 51 ReductionNode(ReductionNode *parent, const std::vector<Range> &range, 52 llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator); 53 getParent()54 ReductionNode *getParent() const { return parent; } 55 56 /// If the ReductionNode hasn't been tested the interestingness, it'll be the 57 /// same module as the one in the parent node. Otherwise, the returned module 58 /// will have been applied certain reduction strategies. Note that it's not 59 /// necessary to be an interesting case or a reduced module (has smaller size 60 /// than parent's). getModule()61 ModuleOp getModule() const { return module.get(); } 62 63 /// Return the region we're reducing. getRegion()64 Region &getRegion() const { return *region; } 65 66 /// Return the size of the module. getSize()67 size_t getSize() const { return size; } 68 69 /// Returns true if the module exhibits the interesting behavior. isInteresting()70 Tester::Interestingness isInteresting() const { return interesting; } 71 72 /// Return the range information that how this node is reduced from the parent 73 /// node. getStartRanges()74 ArrayRef<Range> getStartRanges() const { return startRanges; } 75 76 /// Return the range set we are using to generate variants. getRanges()77 ArrayRef<Range> getRanges() const { return ranges; } 78 79 /// Return the generated variants(the child nodes). getVariants()80 ArrayRef<ReductionNode *> getVariants() const { return variants; } 81 82 /// Split the ranges and generate new variants. 83 ArrayRef<ReductionNode *> generateNewVariants(); 84 85 /// Update the interestingness result from tester. 86 void update(std::pair<Tester::Interestingness, size_t> result); 87 88 /// Each Reduction Node contains a copy of module for applying rewrite 89 /// patterns. In addition, we only apply rewrite patterns in a certain region. 90 /// In init(), we will duplicate the module from parent node and locate the 91 /// corresponding region. 92 LogicalResult initialize(ModuleOp parentModule, Region &parentRegion); 93 94 private: 95 /// A custom BFS iterator. The difference between 96 /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic. 97 /// We may explore more neighbors at certain node if we didn't find interested 98 /// event. As a result, we defer pushing adjacent nodes until poping the last 99 /// visited node. The graph exploration strategy will be put in 100 /// getNeighbors(). 101 /// 102 /// Subclass BaseIterator and implement traversal strategy in getNeighbors(). 103 template <typename T> 104 class BaseIterator { 105 public: BaseIterator(ReductionNode * node)106 BaseIterator(ReductionNode *node) { visitQueue.push(node); } 107 BaseIterator(const BaseIterator &) = default; 108 BaseIterator() = default; 109 end()110 static BaseIterator end() { return BaseIterator(); } 111 112 bool operator==(const BaseIterator &i) { 113 return visitQueue == i.visitQueue; 114 } 115 bool operator!=(const BaseIterator &i) { return !(*this == i); } 116 117 BaseIterator &operator++() { 118 ReductionNode *top = visitQueue.front(); 119 visitQueue.pop(); 120 for (ReductionNode *node : getNeighbors(top)) 121 visitQueue.push(node); 122 return *this; 123 } 124 125 BaseIterator operator++(int) { 126 BaseIterator tmp = *this; 127 ++*this; 128 return tmp; 129 } 130 131 ReductionNode &operator*() const { return *(visitQueue.front()); } 132 ReductionNode *operator->() const { return visitQueue.front(); } 133 134 protected: getNeighbors(ReductionNode * node)135 ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node) { 136 return static_cast<T *>(this)->getNeighbors(node); 137 } 138 139 private: 140 std::queue<ReductionNode *> visitQueue; 141 }; 142 143 /// This is a copy of module from parent node. All the reducer patterns will 144 /// be applied to this instance. 145 OwningOpRef<ModuleOp> module; 146 147 /// The region of certain operation we're reducing in the module 148 Region *region = nullptr; 149 150 /// The node we are reduced from. It means we will be in variants of parent 151 /// node. 152 ReductionNode *parent = nullptr; 153 154 /// The size of module after applying the reducer patterns with range 155 /// constraints. This is only valid while the interestingness has been tested. 156 size_t size = 0; 157 158 /// This is true if the module has been evaluated and it exhibits the 159 /// interesting behavior. 160 Tester::Interestingness interesting = Tester::Interestingness::Untested; 161 162 /// `ranges` represents the selected subset of operations in the region. We 163 /// implicitly number each operation in the region and ReductionTreePass will 164 /// apply reducer patterns on the operation falls into the `ranges`. We will 165 /// generate new ReductionNode with subset of `ranges` to see if we can do 166 /// further reduction. we may split the element in the `ranges` so that we can 167 /// have more subset variants from `ranges`. 168 /// Note that after applying the reducer patterns the number of operation in 169 /// the region may have changed, we need to update the `ranges` after that. 170 std::vector<Range> ranges; 171 172 /// `startRanges` records the ranges of operations selected from the parent 173 /// node to produce this ReductionNode. It can be used to construct the 174 /// reduction path from the root. I.e., if we apply the same reducer patterns 175 /// and `startRanges` selection on the parent region, we will get the same 176 /// module as this node. 177 const std::vector<Range> startRanges; 178 179 /// This points to the child variants that were created using this node as a 180 /// starting point. 181 std::vector<ReductionNode *> variants; 182 183 llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator; 184 }; 185 186 // Specialized iterator for SinglePath traversal 187 template <> 188 class ReductionNode::iterator<SinglePath> 189 : public BaseIterator<iterator<SinglePath>> { 190 friend BaseIterator<iterator<SinglePath>>; 191 using BaseIterator::BaseIterator; 192 ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node); 193 }; 194 195 } // namespace mlir 196 197 #endif // MLIR_REDUCER_REDUCTIONNODE_H 198