//===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the execution engine for MLIR modules based on LLVM Orc
// JIT engine.
//
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Target/LLVMIR.h"

#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/ExecutionEngine/ObjectCache.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetRegistry.h"

using namespace mlir;
using llvm::dbgs;
using llvm::Error;
using llvm::errs;
using llvm::Expected;
using llvm::LLVMContext;
using llvm::MemoryBuffer;
using llvm::MemoryBufferRef;
using llvm::Module;
using llvm::SectionMemoryManager;
using llvm::StringError;
using llvm::Triple;
using llvm::orc::DynamicLibrarySearchGenerator;
using llvm::orc::ExecutionSession;
using llvm::orc::IRCompileLayer;
using llvm::orc::JITTargetMachineBuilder;
using llvm::orc::RTDyldObjectLinkingLayer;
using llvm::orc::ThreadSafeModule;
using llvm::orc::TMOwningSimpleCompiler;

// Wrap a string into an llvm::StringError.
static inline Error make_string_error(const llvm::Twine &message) {
  return llvm::make_error<StringError>(message.str(),
                                       llvm::inconvertibleErrorCode());
}

namespace mlir {

void SimpleObjectCache::notifyObjectCompiled(const Module *M,
                                             MemoryBufferRef ObjBuffer) {
  CachedObjects[M->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
      ObjBuffer.getBuffer(), ObjBuffer.getBufferIdentifier());
}

std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *M) {
  auto I = CachedObjects.find(M->getModuleIdentifier());
  if (I == CachedObjects.end()) {
    dbgs() << "No object for " << M->getModuleIdentifier()
           << " in cache. Compiling.\n";
    return nullptr;
  }
  dbgs() << "Object for " << M->getModuleIdentifier()
         << " loaded from cache.\n";
  return MemoryBuffer::getMemBuffer(I->second->getMemBufferRef());
}

// Set up the LLVM target triple from the current machine.
bool ExecutionEngine::setupTargetTriple(Module *llvmModule) {
  // Set up the machine properties from the current architecture.
  auto targetTriple = llvm::sys::getDefaultTargetTriple();
  std::string errorMessage;
  auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage);
  if (!target) {
    errs() << "No target: " << errorMessage << "\n";
    return true;
  }
  auto machine =
      target->createTargetMachine(targetTriple, "generic", "", {}, {});
  llvmModule->setDataLayout(machine->createDataLayout());
  llvmModule->setTargetTriple(targetTriple);
  return false;
}

static std::string makePackedFunctionName(StringRef name) {
  return "_mlir_" + name.str();
}

// For each function in the LLVM module, define an interface function that
// wraps all the arguments of the original function and all its results into an
// i8** pointer to provide a unified invocation interface.
void packFunctionArguments(Module *module) {
  auto &ctx = module->getContext();
  llvm::IRBuilder<> builder(ctx);
  llvm::DenseSet<llvm::Function *> interfaceFunctions;
  for (auto &func : module->getFunctionList()) {
    if (func.isDeclaration()) {
      continue;
    }
    if (interfaceFunctions.count(&func)) {
      continue;
    }

    // Given a function `foo(<...>)`, define the interface function
    // `_mlir_foo(i8**)`.
    auto newType = llvm::FunctionType::get(
        builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
        /*isVarArg=*/false);
    auto newName = makePackedFunctionName(func.getName());
    auto funcCst = module->getOrInsertFunction(newName, newType);
    llvm::Function *interfaceFunc =
        llvm::cast<llvm::Function>(funcCst.getCallee());
    interfaceFunctions.insert(interfaceFunc);

    // Extract the arguments from the type-erased argument list and cast them
    // to the proper types.
    auto bb = llvm::BasicBlock::Create(ctx);
    bb->insertInto(interfaceFunc);
    builder.SetInsertPoint(bb);
    llvm::Value *argList = interfaceFunc->arg_begin();
    llvm::SmallVector<llvm::Value *, 8> args;
    args.reserve(llvm::size(func.args()));
    for (auto &indexedArg : llvm::enumerate(func.args())) {
      llvm::Value *argIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), llvm::APInt(64, indexedArg.index()));
      llvm::Value *argPtrPtr = builder.CreateGEP(argList, argIndex);
      llvm::Value *argPtr = builder.CreateLoad(argPtrPtr);
      argPtr = builder.CreateBitCast(
          argPtr, indexedArg.value().getType()->getPointerTo());
      llvm::Value *arg = builder.CreateLoad(argPtr);
      args.push_back(arg);
    }

    // Call the implementation function with the extracted arguments.
    llvm::Value *result = builder.CreateCall(&func, args);

    // Assuming the result is one value, potentially of type `void`.
    if (!result->getType()->isVoidTy()) {
      llvm::Value *retIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), llvm::APInt(64, llvm::size(func.args())));
      llvm::Value *retPtrPtr = builder.CreateGEP(argList, retIndex);
      llvm::Value *retPtr = builder.CreateLoad(retPtrPtr);
      retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo());
      builder.CreateStore(result, retPtr);
    }

    // The interface function returns void.
    builder.CreateRetVoid();
  }
}
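
// As an illustration of the calling convention produced above (a hypothetical
// example, not code emitted by this file): for an implementation function
// `float foo(i32, float)`, the generated `_mlir_foo` wrapper behaves roughly
// like the following pseudo-C:
//
//   void _mlir_foo(i8 **args) {
//     i32 arg0 = *(i32 *)args[0];
//     float arg1 = *(float *)args[1];
//     float result = foo(arg0, arg1);
//     *(float *)args[2] = result;
//   }
//
// That is, arguments are read through args[0..N-1] and a non-void result is
// written through args[N].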

Expected<std::unique_ptr<ExecutionEngine>>
ExecutionEngine::create(ModuleOp m,
                        std::function<Error(llvm::Module *)> transformer,
                        ArrayRef<StringRef> sharedLibPaths) {
  auto engine = std::make_unique<ExecutionEngine>();

  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
  auto llvmModule = translateModuleToLLVMIR(m);
  if (!llvmModule)
    return make_string_error("could not convert to LLVM IR");
  // FIXME: the triple should be passed to the translation or dialect
  // conversion instead of this. Currently, the LLVM module created above has
  // no triple associated with it.
  setupTargetTriple(llvmModule.get());
  packFunctionArguments(llvmModule.get());

  // Clone the module into a new LLVMContext since translateModuleToLLVMIR
  // buries ownership too deeply.
  // TODO(zinenko): Reevaluate model of ownership of LLVMContext in LLVMDialect.
  SmallVector<char, 1> buffer;
  {
    llvm::raw_svector_ostream os(buffer);
    WriteBitcodeToFile(*llvmModule, os);
  }
  llvm::MemoryBufferRef bufferRef(llvm::StringRef(buffer.data(), buffer.size()),
                                  "cloned module buffer");
  auto expectedModule = parseBitcodeFile(bufferRef, *ctx);
  if (!expectedModule)
    return expectedModule.takeError();
  std::unique_ptr<Module> deserModule = std::move(*expectedModule);

  // Callback to create the object layer with symbol resolution to the current
  // process and dynamically linked libraries.
  auto objectLinkingLayerCreator = [&](ExecutionSession &session,
                                       const Triple &TT) {
    auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
        session, []() { return std::make_unique<SectionMemoryManager>(); });
    auto dataLayout = deserModule->getDataLayout();

    // Resolve symbols that are statically linked in the current process.
    session.getMainJITDylib().addGenerator(
        cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
            dataLayout.getGlobalPrefix())));

    // Resolve symbols from shared libraries.
    for (auto libPath : sharedLibPaths) {
      auto mb = llvm::MemoryBuffer::getFile(libPath);
      if (!mb) {
        errs() << "Failed to create MemoryBuffer for: " << libPath << "\n";
        continue;
      }
      auto &JD = session.createJITDylib(libPath);
      auto loaded = DynamicLibrarySearchGenerator::Load(
          libPath.data(), dataLayout.getGlobalPrefix());
      if (!loaded) {
        errs() << "Could not load: " << libPath << "\n";
        continue;
      }
      JD.addGenerator(std::move(*loaded));
      cantFail(objectLayer->add(JD, std::move(mb.get())));
    }

    return objectLayer;
  };

  // Callback to inspect the cache and recompile on demand. This follows Lang's
  // LLJITWithObjectCache example.
  auto compileFunctionCreator = [&](JITTargetMachineBuilder JTMB)
      -> Expected<IRCompileLayer::CompileFunction> {
    auto TM = JTMB.createTargetMachine();
    if (!TM)
      return TM.takeError();
    return IRCompileLayer::CompileFunction(
        TMOwningSimpleCompiler(std::move(*TM), engine->cache.get()));
  };

  // Create the LLJIT by calling the LLJITBuilder with the two callbacks.
  auto jit =
      cantFail(llvm::orc::LLJITBuilder()
                   .setCompileFunctionCreator(compileFunctionCreator)
                   .setObjectLinkingLayerCreator(objectLinkingLayerCreator)
                   .create());

  // Add a ThreadSafeModule to the engine and return.
  ThreadSafeModule tsm(std::move(deserModule), std::move(ctx));
  cantFail(jit->addIRModule(std::move(tsm)));
  engine->jit = std::move(jit);

  return std::move(engine);
}
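
// A minimal creation sketch (hypothetical client code, not part of this file),
// assuming `m` is a ModuleOp already lowered to the LLVM dialect:
//
//   auto expectedEngine =
//       ExecutionEngine::create(m, /*transformer=*/{}, /*sharedLibPaths=*/{});
//   if (!expectedEngine)
//     return expectedEngine.takeError();
//   auto engine = std::move(*expectedEngine);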

Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
  auto expectedSymbol = jit->lookup(makePackedFunctionName(name));
  if (!expectedSymbol)
    return expectedSymbol.takeError();
  auto rawFPtr = expectedSymbol->getAddress();
  auto fptr = reinterpret_cast<void (*)(void **)>(rawFPtr);
  if (!fptr)
    return make_string_error("looked up function is null");
  return fptr;
}

Error ExecutionEngine::invoke(StringRef name, MutableArrayRef<void *> args) {
  auto expectedFPtr = lookup(name);
  if (!expectedFPtr)
    return expectedFPtr.takeError();
  auto fptr = *expectedFPtr;

  (*fptr)(args.data());

  return Error::success();
}

} // end namespace mlir
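
// A minimal invocation sketch (hypothetical client code, not part of this
// file): for a JIT-compiled function `i32 foo(i32)`, arguments and the result
// are passed as an array of opaque pointers, matching the `_mlir_`-prefixed
// wrapper emitted by packFunctionArguments above:
//
//   int32_t input = 42;
//   int32_t output = 0;
//   llvm::SmallVector<void *, 2> args = {&input, &output};
//   if (auto err = engine->invoke("foo", args))
//     llvm::logAllUnhandledErrors(std::move(err), llvm::errs(),
//                                 "JIT invocation failed: ");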