1 //===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the Opt Reduction Pass Wrapper. It creates an MLIR pass that
// runs an arbitrary optimization pass pipeline on a copy of the module and only
// replaces the output module with the transformed version if it is smaller and
// still interesting.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Pass/PassManager.h"
16 #include "mlir/Pass/PassRegistry.h"
17 #include "mlir/Reducer/PassDetail.h"
18 #include "mlir/Reducer/Passes.h"
19 #include "mlir/Reducer/Tester.h"
20 #include "llvm/Support/Debug.h"
21 
22 #define DEBUG_TYPE "mlir-reduce"
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 class OptReductionPass : public OptReductionBase<OptReductionPass> {
29 public:
30   /// Runs the pass instance in the pass pipeline.
31   void runOnOperation() override;
32 };
33 
34 } // end anonymous namespace
35 
36 /// Runs the pass instance in the pass pipeline.
37 void OptReductionPass::runOnOperation() {
38   LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: ");
39 
40   Tester test(testerName, testerArgs);
41 
42   ModuleOp module = this->getOperation();
43   ModuleOp moduleVariant = module.clone();
44 
45   PassManager passManager(module.getContext());
46   if (failed(parsePassPipeline(optPass, passManager))) {
47     module.emitError() << "\nfailed to parse pass pipeline";
48     return signalPassFailure();
49   }
50 
51   std::pair<Tester::Interestingness, int> original = test.isInteresting(module);
52   if (original.first != Tester::Interestingness::True) {
53     module.emitError() << "\nthe original input is not interested";
54     return signalPassFailure();
55   }
56 
57   if (failed(passManager.run(moduleVariant))) {
58     module.emitError() << "\nfailed to run pass pipeline";
59     return signalPassFailure();
60   }
61 
62   std::pair<Tester::Interestingness, int> reduced =
63       test.isInteresting(moduleVariant);
64 
65   if (reduced.first == Tester::Interestingness::True &&
66       reduced.second < original.second) {
67     module.getBody()->clear();
68     module.getBody()->getOperations().splice(
69         module.getBody()->begin(), moduleVariant.getBody()->getOperations());
70     LLVM_DEBUG(llvm::dbgs() << "\nSuccessful Transformed version\n\n");
71   } else {
72     LLVM_DEBUG(llvm::dbgs() << "\nUnsuccessful Transformed version\n\n");
73   }
74 
75   moduleVariant->destroy();
76 
77   LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n");
78 }
79 
80 std::unique_ptr<Pass> mlir::createOptReductionPass() {
81   return std::make_unique<OptReductionPass>();
82 }
83