//===-- LLJITWithOptimizingIRTransform.cpp -- LLJIT with IR optimization --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// In this example we will use an IR transform to optimize a module as it
// passes through LLJIT's IRTransformLayer.
//
//===----------------------------------------------------------------------===//

#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Scalar.h"

#include "../ExampleModules.h"

using namespace llvm;
using namespace llvm::orc;

// Program-wide error handler. main() wraps every Error / Expected<T> returned
// by the ORC APIs in ExitOnErr(...) to unwrap the value. main() sets its
// banner to the program name. On failure, llvm::ExitOnError reports the error
// with that banner and terminates the program (see llvm/Support/Error.h).
ExitOnError ExitOnErr;

// Example IR module.
//
// This IR contains a recursive definition of the factorial function:
//
//  fac(n) | n == 0    = 1
//         | otherwise = n * fac(n - 1)
//
// It also contains an entry function which calls the factorial function with
// an input value of 5.
//
// We expect the IR optimization transform that we build below to transform
// this into a non-recursive factorial function and an entry function that
// returns a constant value of 5!, or 120.
42 43 const llvm::StringRef MainMod = 44 R"( 45 46 define i32 @fac(i32 %n) { 47 entry: 48 %tobool = icmp eq i32 %n, 0 49 br i1 %tobool, label %return, label %if.then 50 51 if.then: ; preds = %entry 52 %arg = add nsw i32 %n, -1 53 %call_result = call i32 @fac(i32 %arg) 54 %result = mul nsw i32 %n, %call_result 55 br label %return 56 57 return: ; preds = %entry, %if.then 58 %final_result = phi i32 [ %result, %if.then ], [ 1, %entry ] 59 ret i32 %final_result 60 } 61 62 define i32 @entry() { 63 entry: 64 %result = call i32 @fac(i32 5) 65 ret i32 %result 66 } 67 68 )"; 69 70 // A function object that creates a simple pass pipeline to apply to each 71 // module as it passes through the IRTransformLayer. 72 class MyOptimizationTransform { 73 public: 74 MyOptimizationTransform() : PM(std::make_unique<legacy::PassManager>()) { 75 PM->add(createTailCallEliminationPass()); 76 PM->add(createFunctionInliningPass()); 77 PM->add(createIndVarSimplifyPass()); 78 PM->add(createCFGSimplificationPass()); 79 } 80 81 Expected<ThreadSafeModule> operator()(ThreadSafeModule TSM, 82 MaterializationResponsibility &R) { 83 TSM.withModuleDo([this](Module &M) { 84 dbgs() << "--- BEFORE OPTIMIZATION ---\n" << M << "\n"; 85 PM->run(M); 86 dbgs() << "--- AFTER OPTIMIZATION ---\n" << M << "\n"; 87 }); 88 return std::move(TSM); 89 } 90 91 private: 92 std::unique_ptr<legacy::PassManager> PM; 93 }; 94 95 int main(int argc, char *argv[]) { 96 // Initialize LLVM. 97 InitLLVM X(argc, argv); 98 99 InitializeNativeTarget(); 100 InitializeNativeTargetAsmPrinter(); 101 102 ExitOnErr.setBanner(std::string(argv[0]) + ": "); 103 104 // (1) Create LLJIT instance. 105 auto J = ExitOnErr(LLJITBuilder().create()); 106 107 // (2) Install transform to optimize modules when they're materialized. 108 J->getIRTransformLayer().setTransform(MyOptimizationTransform()); 109 110 // (3) Add modules. 
111 ExitOnErr(J->addIRModule(ExitOnErr(parseExampleModule(MainMod, "MainMod")))); 112 113 // (4) Look up the JIT'd function and call it. 114 auto EntrySym = ExitOnErr(J->lookup("entry")); 115 auto *Entry = (int (*)())EntrySym.getAddress(); 116 117 int Result = Entry(); 118 outs() << "--- Result ---\n" 119 << "entry() = " << Result << "\n"; 120 121 return 0; 122 } 123