//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Reduction Tree Pass class. It provides a framework for
// the implementation of different reduction passes in the MLIR Reduce tool. It
// allows for custom specification of the variant generation behavior. It
// implements methods that define the different possible traversals of the
// reduction tree.
//
//===----------------------------------------------------------------------===//
16c484c7ddSChia-hung Duan
17c484c7ddSChia-hung Duan #include "mlir/IR/DialectInterface.h"
18c484c7ddSChia-hung Duan #include "mlir/IR/OpDefinition.h"
19c484c7ddSChia-hung Duan #include "mlir/Reducer/PassDetail.h"
20c484c7ddSChia-hung Duan #include "mlir/Reducer/Passes.h"
21c484c7ddSChia-hung Duan #include "mlir/Reducer/ReductionNode.h"
22c484c7ddSChia-hung Duan #include "mlir/Reducer/ReductionPatternInterface.h"
23c484c7ddSChia-hung Duan #include "mlir/Reducer/Tester.h"
24c484c7ddSChia-hung Duan #include "mlir/Rewrite/FrozenRewritePatternSet.h"
25c484c7ddSChia-hung Duan #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26c484c7ddSChia-hung Duan
27c484c7ddSChia-hung Duan #include "llvm/ADT/ArrayRef.h"
28c484c7ddSChia-hung Duan #include "llvm/ADT/SmallVector.h"
29c484c7ddSChia-hung Duan #include "llvm/Support/Allocator.h"
30c484c7ddSChia-hung Duan #include "llvm/Support/ManagedStatic.h"
31c484c7ddSChia-hung Duan
32c484c7ddSChia-hung Duan using namespace mlir;
33c484c7ddSChia-hung Duan
34c484c7ddSChia-hung Duan /// We implicitly number each operation in the region and if an operation's
35c484c7ddSChia-hung Duan /// number falls into rangeToKeep, we need to keep it and apply the given
36c484c7ddSChia-hung Duan /// rewrite patterns on it.
applyPatterns(Region & region,const FrozenRewritePatternSet & patterns,ArrayRef<ReductionNode::Range> rangeToKeep,bool eraseOpNotInRange)37c484c7ddSChia-hung Duan static void applyPatterns(Region ®ion,
38c484c7ddSChia-hung Duan const FrozenRewritePatternSet &patterns,
39c484c7ddSChia-hung Duan ArrayRef<ReductionNode::Range> rangeToKeep,
40c484c7ddSChia-hung Duan bool eraseOpNotInRange) {
41c484c7ddSChia-hung Duan std::vector<Operation *> opsNotInRange;
42c484c7ddSChia-hung Duan std::vector<Operation *> opsInRange;
43c484c7ddSChia-hung Duan size_t keepIndex = 0;
44e4853be2SMehdi Amini for (const auto &op : enumerate(region.getOps())) {
45c484c7ddSChia-hung Duan int index = op.index();
46c484c7ddSChia-hung Duan if (keepIndex < rangeToKeep.size() &&
47c484c7ddSChia-hung Duan index == rangeToKeep[keepIndex].second)
48c484c7ddSChia-hung Duan ++keepIndex;
49c484c7ddSChia-hung Duan if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
50c484c7ddSChia-hung Duan opsNotInRange.push_back(&op.value());
51c484c7ddSChia-hung Duan else
52c484c7ddSChia-hung Duan opsInRange.push_back(&op.value());
53c484c7ddSChia-hung Duan }
54c484c7ddSChia-hung Duan
55c484c7ddSChia-hung Duan // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
56c484c7ddSChia-hung Duan // matching in above iteration. Besides, erase op not-in-range may end up in
57c484c7ddSChia-hung Duan // invalid module, so `applyOpPatternsAndFold` should come before that
58c484c7ddSChia-hung Duan // transform.
59c484c7ddSChia-hung Duan for (Operation *op : opsInRange)
60c484c7ddSChia-hung Duan // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
61c484c7ddSChia-hung Duan // because we don't have expectation this reduction will be success or not.
62c484c7ddSChia-hung Duan (void)applyOpPatternsAndFold(op, patterns);
63c484c7ddSChia-hung Duan
64c484c7ddSChia-hung Duan if (eraseOpNotInRange)
65c484c7ddSChia-hung Duan for (Operation *op : opsNotInRange) {
66c484c7ddSChia-hung Duan op->dropAllUses();
67c484c7ddSChia-hung Duan op->erase();
68c484c7ddSChia-hung Duan }
69c484c7ddSChia-hung Duan }
70c484c7ddSChia-hung Duan
71c484c7ddSChia-hung Duan /// We will apply the reducer patterns to the operations in the ranges specified
72c484c7ddSChia-hung Duan /// by ReductionNode. Note that we are not able to remove an operation without
73c484c7ddSChia-hung Duan /// replacing it with another valid operation. However, The validity of module
74c484c7ddSChia-hung Duan /// reduction is based on the Tester provided by the user and that means certain
75c484c7ddSChia-hung Duan /// invalid module is still interested by the use. Thus we provide an
76c484c7ddSChia-hung Duan /// alternative way to remove operations, which is using `eraseOpNotInRange` to
77c484c7ddSChia-hung Duan /// erase the operations not in the range specified by ReductionNode.
78c484c7ddSChia-hung Duan template <typename IteratorType>
findOptimal(ModuleOp module,Region & region,const FrozenRewritePatternSet & patterns,const Tester & test,bool eraseOpNotInRange)791a001dedSChia-hung Duan static LogicalResult findOptimal(ModuleOp module, Region ®ion,
80c484c7ddSChia-hung Duan const FrozenRewritePatternSet &patterns,
81c484c7ddSChia-hung Duan const Tester &test, bool eraseOpNotInRange) {
82c484c7ddSChia-hung Duan std::pair<Tester::Interestingness, size_t> initStatus =
83c484c7ddSChia-hung Duan test.isInteresting(module);
84c484c7ddSChia-hung Duan // While exploring the reduction tree, we always branch from an interesting
85c484c7ddSChia-hung Duan // node. Thus the root node must be interesting.
86c484c7ddSChia-hung Duan if (initStatus.first != Tester::Interestingness::True)
871a001dedSChia-hung Duan return module.emitWarning() << "uninterested module will not be reduced";
88c484c7ddSChia-hung Duan
89c484c7ddSChia-hung Duan llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
90c484c7ddSChia-hung Duan
91c484c7ddSChia-hung Duan std::vector<ReductionNode::Range> ranges{
92c484c7ddSChia-hung Duan {0, std::distance(region.op_begin(), region.op_end())}};
93c484c7ddSChia-hung Duan
94c484c7ddSChia-hung Duan ReductionNode *root = allocator.Allocate();
95*337c937dSMehdi Amini new (root) ReductionNode(nullptr, ranges, allocator);
96c484c7ddSChia-hung Duan // Duplicate the module for root node and locate the region in the copy.
97c484c7ddSChia-hung Duan if (failed(root->initialize(module, region)))
98c484c7ddSChia-hung Duan llvm_unreachable("unexpected initialization failure");
99c484c7ddSChia-hung Duan root->update(initStatus);
100c484c7ddSChia-hung Duan
101c484c7ddSChia-hung Duan ReductionNode *smallestNode = root;
102c484c7ddSChia-hung Duan IteratorType iter(root);
103c484c7ddSChia-hung Duan
104c484c7ddSChia-hung Duan while (iter != IteratorType::end()) {
105c484c7ddSChia-hung Duan ReductionNode ¤tNode = *iter;
106c484c7ddSChia-hung Duan Region &curRegion = currentNode.getRegion();
107c484c7ddSChia-hung Duan
108c484c7ddSChia-hung Duan applyPatterns(curRegion, patterns, currentNode.getRanges(),
109c484c7ddSChia-hung Duan eraseOpNotInRange);
110c484c7ddSChia-hung Duan currentNode.update(test.isInteresting(currentNode.getModule()));
111c484c7ddSChia-hung Duan
112c484c7ddSChia-hung Duan if (currentNode.isInteresting() == Tester::Interestingness::True &&
113c484c7ddSChia-hung Duan currentNode.getSize() < smallestNode->getSize())
114c484c7ddSChia-hung Duan smallestNode = ¤tNode;
115c484c7ddSChia-hung Duan
116c484c7ddSChia-hung Duan ++iter;
117c484c7ddSChia-hung Duan }
118c484c7ddSChia-hung Duan
119c484c7ddSChia-hung Duan // At here, we have found an optimal path to reduce the given region. Retrieve
120c484c7ddSChia-hung Duan // the path and apply the reducer to it.
121c484c7ddSChia-hung Duan SmallVector<ReductionNode *> trace;
122c484c7ddSChia-hung Duan ReductionNode *curNode = smallestNode;
123c484c7ddSChia-hung Duan trace.push_back(curNode);
124c484c7ddSChia-hung Duan while (curNode != root) {
125c484c7ddSChia-hung Duan curNode = curNode->getParent();
126c484c7ddSChia-hung Duan trace.push_back(curNode);
127c484c7ddSChia-hung Duan }
128c484c7ddSChia-hung Duan
129c484c7ddSChia-hung Duan // Reduce the region through the optimal path.
130c484c7ddSChia-hung Duan while (!trace.empty()) {
131c484c7ddSChia-hung Duan ReductionNode *top = trace.pop_back_val();
132c484c7ddSChia-hung Duan applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
133c484c7ddSChia-hung Duan }
134c484c7ddSChia-hung Duan
135c484c7ddSChia-hung Duan if (test.isInteresting(module).first != Tester::Interestingness::True)
136c484c7ddSChia-hung Duan llvm::report_fatal_error("Reduced module is not interesting");
137c484c7ddSChia-hung Duan if (test.isInteresting(module).second != smallestNode->getSize())
138c484c7ddSChia-hung Duan llvm::report_fatal_error(
139c484c7ddSChia-hung Duan "Reduced module doesn't have consistent size with smallestNode");
1401a001dedSChia-hung Duan return success();
141c484c7ddSChia-hung Duan }
142c484c7ddSChia-hung Duan
143c484c7ddSChia-hung Duan template <typename IteratorType>
findOptimal(ModuleOp module,Region & region,const FrozenRewritePatternSet & patterns,const Tester & test)1441a001dedSChia-hung Duan static LogicalResult findOptimal(ModuleOp module, Region ®ion,
145c484c7ddSChia-hung Duan const FrozenRewritePatternSet &patterns,
146c484c7ddSChia-hung Duan const Tester &test) {
147c484c7ddSChia-hung Duan // We separate the reduction process into 2 steps, the first one is to erase
148c484c7ddSChia-hung Duan // redundant operations and the second one is to apply the reducer patterns.
149c484c7ddSChia-hung Duan
150c484c7ddSChia-hung Duan // In the first phase, we don't apply any patterns so that we only select the
151c484c7ddSChia-hung Duan // range of operations to keep to the module stay interesting.
1521a001dedSChia-hung Duan if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
1531a001dedSChia-hung Duan /*eraseOpNotInRange=*/true)))
1541a001dedSChia-hung Duan return failure();
155c484c7ddSChia-hung Duan // In the second phase, we suppose that no operation is redundant, so we try
156c484c7ddSChia-hung Duan // to rewrite the operation into simpler form.
1571a001dedSChia-hung Duan return findOptimal<IteratorType>(module, region, patterns, test,
158c484c7ddSChia-hung Duan /*eraseOpNotInRange=*/false);
159c484c7ddSChia-hung Duan }
160c484c7ddSChia-hung Duan
161c484c7ddSChia-hung Duan namespace {
162c484c7ddSChia-hung Duan
163c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
164c484c7ddSChia-hung Duan // Reduction Pattern Interface Collection
165c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
166c484c7ddSChia-hung Duan
167c484c7ddSChia-hung Duan class ReductionPatternInterfaceCollection
168c484c7ddSChia-hung Duan : public DialectInterfaceCollection<DialectReductionPatternInterface> {
169c484c7ddSChia-hung Duan public:
170c484c7ddSChia-hung Duan using Base::Base;
171c484c7ddSChia-hung Duan
172c484c7ddSChia-hung Duan // Collect the reduce patterns defined by each dialect.
populateReductionPatterns(RewritePatternSet & pattern) const173c484c7ddSChia-hung Duan void populateReductionPatterns(RewritePatternSet &pattern) const {
174c484c7ddSChia-hung Duan for (const DialectReductionPatternInterface &interface : *this)
175c484c7ddSChia-hung Duan interface.populateReductionPatterns(pattern);
176c484c7ddSChia-hung Duan }
177c484c7ddSChia-hung Duan };
178c484c7ddSChia-hung Duan
179c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
180c484c7ddSChia-hung Duan // ReductionTreePass
181c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
182c484c7ddSChia-hung Duan
183c484c7ddSChia-hung Duan /// This class defines the Reduction Tree Pass. It provides a framework to
184c484c7ddSChia-hung Duan /// to implement a reduction pass using a tree structure to keep track of the
185c484c7ddSChia-hung Duan /// generated reduced variants.
186c484c7ddSChia-hung Duan class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
187c484c7ddSChia-hung Duan public:
188c484c7ddSChia-hung Duan ReductionTreePass() = default;
189c484c7ddSChia-hung Duan ReductionTreePass(const ReductionTreePass &pass) = default;
190c484c7ddSChia-hung Duan
191c484c7ddSChia-hung Duan LogicalResult initialize(MLIRContext *context) override;
192c484c7ddSChia-hung Duan
193c484c7ddSChia-hung Duan /// Runs the pass instance in the pass pipeline.
194c484c7ddSChia-hung Duan void runOnOperation() override;
195c484c7ddSChia-hung Duan
196c484c7ddSChia-hung Duan private:
1971a001dedSChia-hung Duan LogicalResult reduceOp(ModuleOp module, Region ®ion);
198c484c7ddSChia-hung Duan
199c484c7ddSChia-hung Duan FrozenRewritePatternSet reducerPatterns;
200c484c7ddSChia-hung Duan };
201c484c7ddSChia-hung Duan
202be0a7e9fSMehdi Amini } // namespace
203c484c7ddSChia-hung Duan
initialize(MLIRContext * context)204c484c7ddSChia-hung Duan LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
205c484c7ddSChia-hung Duan RewritePatternSet patterns(context);
206c484c7ddSChia-hung Duan ReductionPatternInterfaceCollection reducePatternCollection(context);
207c484c7ddSChia-hung Duan reducePatternCollection.populateReductionPatterns(patterns);
208c484c7ddSChia-hung Duan reducerPatterns = std::move(patterns);
209c484c7ddSChia-hung Duan return success();
210c484c7ddSChia-hung Duan }
211c484c7ddSChia-hung Duan
runOnOperation()212c484c7ddSChia-hung Duan void ReductionTreePass::runOnOperation() {
213c484c7ddSChia-hung Duan Operation *topOperation = getOperation();
214c484c7ddSChia-hung Duan while (topOperation->getParentOp() != nullptr)
215c484c7ddSChia-hung Duan topOperation = topOperation->getParentOp();
216c484c7ddSChia-hung Duan ModuleOp module = cast<ModuleOp>(topOperation);
217c484c7ddSChia-hung Duan
218c484c7ddSChia-hung Duan SmallVector<Operation *, 8> workList;
219c484c7ddSChia-hung Duan workList.push_back(getOperation());
220c484c7ddSChia-hung Duan
221c484c7ddSChia-hung Duan do {
222c484c7ddSChia-hung Duan Operation *op = workList.pop_back_val();
223c484c7ddSChia-hung Duan
224c484c7ddSChia-hung Duan for (Region ®ion : op->getRegions())
225c484c7ddSChia-hung Duan if (!region.empty())
2261a001dedSChia-hung Duan if (failed(reduceOp(module, region)))
2271a001dedSChia-hung Duan return signalPassFailure();
228c484c7ddSChia-hung Duan
229c484c7ddSChia-hung Duan for (Region ®ion : op->getRegions())
230c484c7ddSChia-hung Duan for (Operation &op : region.getOps())
231c484c7ddSChia-hung Duan if (op.getNumRegions() != 0)
232c484c7ddSChia-hung Duan workList.push_back(&op);
233c484c7ddSChia-hung Duan } while (!workList.empty());
234c484c7ddSChia-hung Duan }
235c484c7ddSChia-hung Duan
reduceOp(ModuleOp module,Region & region)2361a001dedSChia-hung Duan LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) {
237c484c7ddSChia-hung Duan Tester test(testerName, testerArgs);
238c484c7ddSChia-hung Duan switch (traversalModeId) {
239c484c7ddSChia-hung Duan case TraversalMode::SinglePath:
2401a001dedSChia-hung Duan return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
241c484c7ddSChia-hung Duan module, region, reducerPatterns, test);
242c484c7ddSChia-hung Duan default:
2431a001dedSChia-hung Duan return module.emitError() << "unsupported traversal mode detected";
244c484c7ddSChia-hung Duan }
245c484c7ddSChia-hung Duan }
246c484c7ddSChia-hung Duan
createReductionTreePass()247c484c7ddSChia-hung Duan std::unique_ptr<Pass> mlir::createReductionTreePass() {
248c484c7ddSChia-hung Duan return std::make_unique<ReductionTreePass>();
249c484c7ddSChia-hung Duan }
250