1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "mlir/IR/DialectInterface.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/Reducer/PassDetail.h"
20 #include "mlir/Reducer/Passes.h"
21 #include "mlir/Reducer/ReductionNode.h"
22 #include "mlir/Reducer/ReductionPatternInterface.h"
23 #include "mlir/Reducer/Tester.h"
24 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26
27 #include "llvm/ADT/ArrayRef.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/Support/Allocator.h"
30 #include "llvm/Support/ManagedStatic.h"
31
32 using namespace mlir;
33
34 /// We implicitly number each operation in the region and if an operation's
35 /// number falls into rangeToKeep, we need to keep it and apply the given
36 /// rewrite patterns on it.
applyPatterns(Region & region,const FrozenRewritePatternSet & patterns,ArrayRef<ReductionNode::Range> rangeToKeep,bool eraseOpNotInRange)37 static void applyPatterns(Region ®ion,
38 const FrozenRewritePatternSet &patterns,
39 ArrayRef<ReductionNode::Range> rangeToKeep,
40 bool eraseOpNotInRange) {
41 std::vector<Operation *> opsNotInRange;
42 std::vector<Operation *> opsInRange;
43 size_t keepIndex = 0;
44 for (const auto &op : enumerate(region.getOps())) {
45 int index = op.index();
46 if (keepIndex < rangeToKeep.size() &&
47 index == rangeToKeep[keepIndex].second)
48 ++keepIndex;
49 if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
50 opsNotInRange.push_back(&op.value());
51 else
52 opsInRange.push_back(&op.value());
53 }
54
55 // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
56 // matching in above iteration. Besides, erase op not-in-range may end up in
57 // invalid module, so `applyOpPatternsAndFold` should come before that
58 // transform.
59 for (Operation *op : opsInRange)
60 // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
61 // because we don't have expectation this reduction will be success or not.
62 (void)applyOpPatternsAndFold(op, patterns);
63
64 if (eraseOpNotInRange)
65 for (Operation *op : opsNotInRange) {
66 op->dropAllUses();
67 op->erase();
68 }
69 }
70
71 /// We will apply the reducer patterns to the operations in the ranges specified
72 /// by ReductionNode. Note that we are not able to remove an operation without
73 /// replacing it with another valid operation. However, The validity of module
74 /// reduction is based on the Tester provided by the user and that means certain
75 /// invalid module is still interested by the use. Thus we provide an
76 /// alternative way to remove operations, which is using `eraseOpNotInRange` to
77 /// erase the operations not in the range specified by ReductionNode.
78 template <typename IteratorType>
findOptimal(ModuleOp module,Region & region,const FrozenRewritePatternSet & patterns,const Tester & test,bool eraseOpNotInRange)79 static LogicalResult findOptimal(ModuleOp module, Region ®ion,
80 const FrozenRewritePatternSet &patterns,
81 const Tester &test, bool eraseOpNotInRange) {
82 std::pair<Tester::Interestingness, size_t> initStatus =
83 test.isInteresting(module);
84 // While exploring the reduction tree, we always branch from an interesting
85 // node. Thus the root node must be interesting.
86 if (initStatus.first != Tester::Interestingness::True)
87 return module.emitWarning() << "uninterested module will not be reduced";
88
89 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
90
91 std::vector<ReductionNode::Range> ranges{
92 {0, std::distance(region.op_begin(), region.op_end())}};
93
94 ReductionNode *root = allocator.Allocate();
95 new (root) ReductionNode(nullptr, ranges, allocator);
96 // Duplicate the module for root node and locate the region in the copy.
97 if (failed(root->initialize(module, region)))
98 llvm_unreachable("unexpected initialization failure");
99 root->update(initStatus);
100
101 ReductionNode *smallestNode = root;
102 IteratorType iter(root);
103
104 while (iter != IteratorType::end()) {
105 ReductionNode ¤tNode = *iter;
106 Region &curRegion = currentNode.getRegion();
107
108 applyPatterns(curRegion, patterns, currentNode.getRanges(),
109 eraseOpNotInRange);
110 currentNode.update(test.isInteresting(currentNode.getModule()));
111
112 if (currentNode.isInteresting() == Tester::Interestingness::True &&
113 currentNode.getSize() < smallestNode->getSize())
114 smallestNode = ¤tNode;
115
116 ++iter;
117 }
118
119 // At here, we have found an optimal path to reduce the given region. Retrieve
120 // the path and apply the reducer to it.
121 SmallVector<ReductionNode *> trace;
122 ReductionNode *curNode = smallestNode;
123 trace.push_back(curNode);
124 while (curNode != root) {
125 curNode = curNode->getParent();
126 trace.push_back(curNode);
127 }
128
129 // Reduce the region through the optimal path.
130 while (!trace.empty()) {
131 ReductionNode *top = trace.pop_back_val();
132 applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
133 }
134
135 if (test.isInteresting(module).first != Tester::Interestingness::True)
136 llvm::report_fatal_error("Reduced module is not interesting");
137 if (test.isInteresting(module).second != smallestNode->getSize())
138 llvm::report_fatal_error(
139 "Reduced module doesn't have consistent size with smallestNode");
140 return success();
141 }
142
143 template <typename IteratorType>
findOptimal(ModuleOp module,Region & region,const FrozenRewritePatternSet & patterns,const Tester & test)144 static LogicalResult findOptimal(ModuleOp module, Region ®ion,
145 const FrozenRewritePatternSet &patterns,
146 const Tester &test) {
147 // We separate the reduction process into 2 steps, the first one is to erase
148 // redundant operations and the second one is to apply the reducer patterns.
149
150 // In the first phase, we don't apply any patterns so that we only select the
151 // range of operations to keep to the module stay interesting.
152 if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
153 /*eraseOpNotInRange=*/true)))
154 return failure();
155 // In the second phase, we suppose that no operation is redundant, so we try
156 // to rewrite the operation into simpler form.
157 return findOptimal<IteratorType>(module, region, patterns, test,
158 /*eraseOpNotInRange=*/false);
159 }
160
161 namespace {
162
163 //===----------------------------------------------------------------------===//
164 // Reduction Pattern Interface Collection
165 //===----------------------------------------------------------------------===//
166
167 class ReductionPatternInterfaceCollection
168 : public DialectInterfaceCollection<DialectReductionPatternInterface> {
169 public:
170 using Base::Base;
171
172 // Collect the reduce patterns defined by each dialect.
populateReductionPatterns(RewritePatternSet & pattern) const173 void populateReductionPatterns(RewritePatternSet &pattern) const {
174 for (const DialectReductionPatternInterface &interface : *this)
175 interface.populateReductionPatterns(pattern);
176 }
177 };
178
179 //===----------------------------------------------------------------------===//
180 // ReductionTreePass
181 //===----------------------------------------------------------------------===//
182
183 /// This class defines the Reduction Tree Pass. It provides a framework to
184 /// to implement a reduction pass using a tree structure to keep track of the
185 /// generated reduced variants.
186 class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
187 public:
188 ReductionTreePass() = default;
189 ReductionTreePass(const ReductionTreePass &pass) = default;
190
191 LogicalResult initialize(MLIRContext *context) override;
192
193 /// Runs the pass instance in the pass pipeline.
194 void runOnOperation() override;
195
196 private:
197 LogicalResult reduceOp(ModuleOp module, Region ®ion);
198
199 FrozenRewritePatternSet reducerPatterns;
200 };
201
202 } // namespace
203
initialize(MLIRContext * context)204 LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
205 RewritePatternSet patterns(context);
206 ReductionPatternInterfaceCollection reducePatternCollection(context);
207 reducePatternCollection.populateReductionPatterns(patterns);
208 reducerPatterns = std::move(patterns);
209 return success();
210 }
211
runOnOperation()212 void ReductionTreePass::runOnOperation() {
213 Operation *topOperation = getOperation();
214 while (topOperation->getParentOp() != nullptr)
215 topOperation = topOperation->getParentOp();
216 ModuleOp module = cast<ModuleOp>(topOperation);
217
218 SmallVector<Operation *, 8> workList;
219 workList.push_back(getOperation());
220
221 do {
222 Operation *op = workList.pop_back_val();
223
224 for (Region ®ion : op->getRegions())
225 if (!region.empty())
226 if (failed(reduceOp(module, region)))
227 return signalPassFailure();
228
229 for (Region ®ion : op->getRegions())
230 for (Operation &op : region.getOps())
231 if (op.getNumRegions() != 0)
232 workList.push_back(&op);
233 } while (!workList.empty());
234 }
235
reduceOp(ModuleOp module,Region & region)236 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) {
237 Tester test(testerName, testerArgs);
238 switch (traversalModeId) {
239 case TraversalMode::SinglePath:
240 return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
241 module, region, reducerPatterns, test);
242 default:
243 return module.emitError() << "unsupported traversal mode detected";
244 }
245 }
246
createReductionTreePass()247 std::unique_ptr<Pass> mlir::createReductionTreePass() {
248 return std::make_unique<ReductionTreePass>();
249 }
250