1 //===- LowerGPUToHSACO.cpp - Convert GPU kernel to HSACO blob -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass that serializes a gpu module into HSAco blob and 10 // adds that blob as a string attribute of the module. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "mlir/Dialect/GPU/Passes.h" 14 #include "mlir/IR/Location.h" 15 #include "mlir/IR/MLIRContext.h" 16 17 #if MLIR_GPU_TO_HSACO_PASS_ENABLE 18 #include "mlir/ExecutionEngine/OptUtils.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Support/FileUtilities.h" 21 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" 22 #include "mlir/Target/LLVMIR/Export.h" 23 24 #include "llvm/IR/Constants.h" 25 #include "llvm/IR/GlobalVariable.h" 26 #include "llvm/IR/Module.h" 27 #include "llvm/IRReader/IRReader.h" 28 #include "llvm/Linker/Linker.h" 29 30 #include "llvm/MC/MCAsmBackend.h" 31 #include "llvm/MC/MCAsmInfo.h" 32 #include "llvm/MC/MCCodeEmitter.h" 33 #include "llvm/MC/MCContext.h" 34 #include "llvm/MC/MCObjectFileInfo.h" 35 #include "llvm/MC/MCObjectWriter.h" 36 #include "llvm/MC/MCParser/MCTargetAsmParser.h" 37 #include "llvm/MC/MCStreamer.h" 38 #include "llvm/MC/MCSubtargetInfo.h" 39 #include "llvm/MC/TargetRegistry.h" 40 41 #include "llvm/Support/CommandLine.h" 42 #include "llvm/Support/FileUtilities.h" 43 #include "llvm/Support/Program.h" 44 #include "llvm/Support/SourceMgr.h" 45 #include "llvm/Support/TargetSelect.h" 46 #include "llvm/Support/WithColor.h" 47 48 #include "llvm/Target/TargetMachine.h" 49 #include "llvm/Target/TargetOptions.h" 50 51 #include "llvm/Transforms/IPO/Internalize.h" 52 53 #include "lld/Common/Driver.h" 54 55 #include <mutex> 56 57 using namespace mlir; 58 59 namespace { 60 class SerializeToHsacoPass 61 : public PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass> { 62 public: 63 SerializeToHsacoPass(StringRef triple, StringRef arch, StringRef features, 64 int optLevel); 65 SerializeToHsacoPass(const SerializeToHsacoPass &other); 66 StringRef getArgument() const override { return "gpu-to-hsaco"; } 67 StringRef getDescription() const override { 68 return "Lower GPU kernel function to HSACO binary annotations"; 69 } 70 71 protected: 72 Option<int> optLevel{ 73 *this, "opt-level", 74 llvm::cl::desc("Optimization level for HSACO compilation"), 75 llvm::cl::init(2)}; 76 77 Option<std::string> rocmPath{*this, "rocm-path", 78 llvm::cl::desc("Path to ROCm install")}; 79 80 // Overload to allow linking in device libs 81 std::unique_ptr<llvm::Module> 82 translateToLLVMIR(llvm::LLVMContext &llvmContext) override; 83 84 /// Adds LLVM optimization passes 85 LogicalResult optimizeLlvm(llvm::Module &llvmModule, 86 llvm::TargetMachine &targetMachine) override; 87 88 private: 89 void getDependentDialects(DialectRegistry ®istry) const override; 90 91 // Loads LLVM bitcode libraries 92 Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>> 93 loadLibraries(SmallVectorImpl<char> &path, 94 SmallVectorImpl<StringRef> &libraries, 95 llvm::LLVMContext &context); 96 97 // Serializes ROCDL to HSACO. 98 std::unique_ptr<std::vector<char>> 99 serializeISA(const std::string &isa) override; 100 101 std::unique_ptr<SmallVectorImpl<char>> assembleIsa(const std::string &isa); 102 std::unique_ptr<std::vector<char>> 103 createHsaco(const SmallVectorImpl<char> &isaBinary); 104 105 std::string getRocmPath(); 106 }; 107 } // end namespace 108 109 SerializeToHsacoPass::SerializeToHsacoPass(const SerializeToHsacoPass &other) 110 : PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass>(other) {} 111 112 /// Get a user-specified path to ROCm 113 // Tries, in order, the --rocm-path option, the ROCM_PATH environment variable 114 // and a compile-time default 115 std::string SerializeToHsacoPass::getRocmPath() { 116 if (rocmPath.getNumOccurrences() > 0) 117 return rocmPath.getValue(); 118 119 return __DEFAULT_ROCM_PATH__; 120 } 121 122 // Sets the 'option' to 'value' unless it already has a value. 123 static void maybeSetOption(Pass::Option<std::string> &option, 124 function_ref<std::string()> getValue) { 125 if (!option.hasValue()) 126 option = getValue(); 127 } 128 129 SerializeToHsacoPass::SerializeToHsacoPass(StringRef triple, StringRef arch, 130 StringRef features, int optLevel) { 131 maybeSetOption(this->triple, [&triple] { return triple.str(); }); 132 maybeSetOption(this->chip, [&arch] { return arch.str(); }); 133 maybeSetOption(this->features, [&features] { return features.str(); }); 134 if (this->optLevel.getNumOccurrences() == 0) 135 this->optLevel.setValue(optLevel); 136 } 137 138 void SerializeToHsacoPass::getDependentDialects( 139 DialectRegistry ®istry) const { 140 registerROCDLDialectTranslation(registry); 141 gpu::SerializeToBlobPass::getDependentDialects(registry); 142 } 143 144 Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>> 145 SerializeToHsacoPass::loadLibraries(SmallVectorImpl<char> &path, 146 SmallVectorImpl<StringRef> &libraries, 147 llvm::LLVMContext &context) { 148 SmallVector<std::unique_ptr<llvm::Module>, 3> ret; 149 size_t dirLength = path.size(); 150 151 if (!llvm::sys::fs::is_directory(path)) { 152 getOperation().emitRemark() << "Bitcode path: " << path 153 << " does not exist or is not a directory\n"; 154 return llvm::None; 155 } 156 157 for (const StringRef file : libraries) { 158 llvm::SMDiagnostic error; 159 llvm::sys::path::append(path, file); 160 llvm::StringRef pathRef(path.data(), path.size()); 161 std::unique_ptr<llvm::Module> library = 162 llvm::getLazyIRFileModule(pathRef, error, context); 163 path.set_size(dirLength); 164 if (!library) { 165 getOperation().emitError() << "Failed to load library " << file 166 << " from " << path << error.getMessage(); 167 return llvm::None; 168 } 169 // Some ROCM builds don't strip this like they should 170 if (auto *openclVersion = library->getNamedMetadata("opencl.ocl.version")) 171 library->eraseNamedMetadata(openclVersion); 172 // Stop spamming us with clang version numbers 173 if (auto *ident = library->getNamedMetadata("llvm.ident")) 174 library->eraseNamedMetadata(ident); 175 ret.push_back(std::move(library)); 176 } 177 178 return ret; 179 } 180 181 std::unique_ptr<llvm::Module> 182 SerializeToHsacoPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) { 183 // MLIR -> LLVM translation 184 std::unique_ptr<llvm::Module> ret = 185 gpu::SerializeToBlobPass::translateToLLVMIR(llvmContext); 186 187 if (!ret) { 188 getOperation().emitOpError("Module lowering failed"); 189 return ret; 190 } 191 // Walk the LLVM module in order to determine if we need to link in device 192 // libs 193 bool needOpenCl = false; 194 bool needOckl = false; 195 bool needOcml = false; 196 for (llvm::Function &f : ret->functions()) { 197 if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) { 198 StringRef funcName = f.getName(); 199 if ("printf" == funcName) 200 needOpenCl = true; 201 if (funcName.startswith("__ockl_")) 202 needOckl = true; 203 if (funcName.startswith("__ocml_")) 204 needOcml = true; 205 } 206 } 207 208 if (needOpenCl) 209 needOcml = needOckl = true; 210 211 // No libraries needed (the typical case) 212 if (!(needOpenCl || needOcml || needOckl)) 213 return ret; 214 215 // Define one of the control constants the ROCm device libraries expect to be 216 // present These constants can either be defined in the module or can be 217 // imported by linking in bitcode that defines the constant. To simplify our 218 // logic, we define the constants into the module we are compiling 219 auto addControlConstant = [&module = *ret](StringRef name, uint32_t value, 220 uint32_t bitwidth) { 221 using llvm::GlobalVariable; 222 if (module.getNamedGlobal(name)) { 223 return; 224 } 225 llvm::IntegerType *type = 226 llvm::IntegerType::getIntNTy(module.getContext(), bitwidth); 227 auto *initializer = llvm::ConstantInt::get(type, value, /*isSigned=*/false); 228 auto *constant = new GlobalVariable( 229 module, type, 230 /*isConstant=*/true, GlobalVariable::LinkageTypes::LinkOnceODRLinkage, 231 initializer, name, 232 /*before=*/nullptr, 233 /*threadLocalMode=*/GlobalVariable::ThreadLocalMode::NotThreadLocal, 234 /*addressSpace=*/4); 235 constant->setUnnamedAddr(GlobalVariable::UnnamedAddr::Local); 236 constant->setVisibility( 237 GlobalVariable::VisibilityTypes::ProtectedVisibility); 238 constant->setAlignment(llvm::MaybeAlign(bitwidth / 8)); 239 }; 240 241 if (needOcml) { 242 // TODO(kdrewnia): Enable math optimizations once we have support for 243 // `-ffast-math`-like options 244 addControlConstant("__oclc_finite_only_opt", 0, 8); 245 addControlConstant("__oclc_daz_opt", 0, 8); 246 addControlConstant("__oclc_correctly_rounded_sqrt32", 1, 8); 247 addControlConstant("__oclc_unsafe_math_opt", 0, 8); 248 } 249 if (needOcml || needOckl) { 250 addControlConstant("__oclc_wavefrontsize64", 1, 8); 251 StringRef chipSet = this->chip.getValue(); 252 if (chipSet.startswith("gfx")) 253 chipSet = chipSet.substr(3); 254 uint32_t minor = 255 llvm::APInt(32, chipSet.substr(chipSet.size() - 2), 16).getZExtValue(); 256 uint32_t major = llvm::APInt(32, chipSet.substr(0, chipSet.size() - 2), 10) 257 .getZExtValue(); 258 uint32_t isaNumber = minor + 1000 * major; 259 addControlConstant("__oclc_ISA_version", isaNumber, 32); 260 } 261 262 // Determine libraries we need to link - order matters due to dependencies 263 llvm::SmallVector<StringRef, 4> libraries; 264 if (needOpenCl) 265 libraries.push_back("opencl.bc"); 266 if (needOcml) 267 libraries.push_back("ocml.bc"); 268 if (needOckl) 269 libraries.push_back("ockl.bc"); 270 271 Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>> mbModules; 272 std::string theRocmPath = getRocmPath(); 273 llvm::SmallString<32> bitcodePath(std::move(theRocmPath)); 274 llvm::sys::path::append(bitcodePath, "amdgcn", "bitcode"); 275 mbModules = loadLibraries(bitcodePath, libraries, llvmContext); 276 277 if (!mbModules) { 278 getOperation() 279 .emitWarning("Could not load required device labraries") 280 .attachNote() 281 << "This will probably cause link-time or run-time failures"; 282 return ret; // We can still abort here 283 } 284 285 llvm::Linker linker(*ret); 286 for (std::unique_ptr<llvm::Module> &libModule : mbModules.getValue()) { 287 // This bitcode linking code is substantially similar to what is used in 288 // hip-clang It imports the library functions into the module, allowing LLVM 289 // optimization passes (which must run after linking) to optimize across the 290 // libraries and the module's code. We also only import symbols if they are 291 // referenced by the module or a previous library since there will be no 292 // other source of references to those symbols in this compilation and since 293 // we don't want to bloat the resulting code object. 294 bool err = linker.linkInModule( 295 std::move(libModule), llvm::Linker::Flags::LinkOnlyNeeded, 296 [](llvm::Module &m, const StringSet<> &gvs) { 297 llvm::internalizeModule(m, [&gvs](const llvm::GlobalValue &gv) { 298 return !gv.hasName() || (gvs.count(gv.getName()) == 0); 299 }); 300 }); 301 // True is linker failure 302 if (err) { 303 getOperation().emitError( 304 "Unrecoverable failure during device library linking."); 305 // We have no guaranties about the state of `ret`, so bail 306 return nullptr; 307 } 308 } 309 return ret; 310 } 311 312 LogicalResult 313 SerializeToHsacoPass::optimizeLlvm(llvm::Module &llvmModule, 314 llvm::TargetMachine &targetMachine) { 315 int optLevel = this->optLevel.getValue(); 316 if (optLevel < 0 || optLevel > 3) 317 return getOperation().emitError() 318 << "Invalid HSA optimization level" << optLevel << "\n"; 319 320 targetMachine.setOptLevel(static_cast<llvm::CodeGenOpt::Level>(optLevel)); 321 322 auto transformer = 323 makeOptimizingTransformer(optLevel, /*sizeLevel=*/0, &targetMachine); 324 auto error = transformer(&llvmModule); 325 if (error) { 326 InFlightDiagnostic mlirError = getOperation()->emitError(); 327 llvm::handleAllErrors( 328 std::move(error), [&mlirError](const llvm::ErrorInfoBase &ei) { 329 mlirError << "Could not optimize LLVM IR: " << ei.message() << "\n"; 330 }); 331 return mlirError; 332 } 333 return success(); 334 } 335 336 std::unique_ptr<SmallVectorImpl<char>> 337 SerializeToHsacoPass::assembleIsa(const std::string &isa) { 338 auto loc = getOperation().getLoc(); 339 340 SmallVector<char, 0> result; 341 llvm::raw_svector_ostream os(result); 342 343 llvm::Triple triple(llvm::Triple::normalize(this->triple)); 344 std::string error; 345 const llvm::Target *target = 346 llvm::TargetRegistry::lookupTarget(triple.normalize(), error); 347 if (!target) { 348 emitError(loc, Twine("failed to lookup target: ") + error); 349 return {}; 350 } 351 352 llvm::SourceMgr srcMgr; 353 srcMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(isa), 354 llvm::SMLoc()); 355 356 const llvm::MCTargetOptions mcOptions; 357 std::unique_ptr<llvm::MCRegisterInfo> mri( 358 target->createMCRegInfo(this->triple)); 359 std::unique_ptr<llvm::MCAsmInfo> mai( 360 target->createMCAsmInfo(*mri, this->triple, mcOptions)); 361 mai->setRelaxELFRelocations(true); 362 std::unique_ptr<llvm::MCSubtargetInfo> sti( 363 target->createMCSubtargetInfo(this->triple, this->chip, this->features)); 364 365 llvm::MCContext ctx(triple, mai.get(), mri.get(), sti.get(), &srcMgr, 366 &mcOptions); 367 std::unique_ptr<llvm::MCObjectFileInfo> mofi(target->createMCObjectFileInfo( 368 ctx, /*PIC=*/false, /*LargeCodeModel=*/false)); 369 ctx.setObjectFileInfo(mofi.get()); 370 371 SmallString<128> cwd; 372 if (!llvm::sys::fs::current_path(cwd)) 373 ctx.setCompilationDir(cwd); 374 375 std::unique_ptr<llvm::MCStreamer> mcStreamer; 376 std::unique_ptr<llvm::MCInstrInfo> mcii(target->createMCInstrInfo()); 377 378 llvm::MCCodeEmitter *ce = target->createMCCodeEmitter(*mcii, *mri, ctx); 379 llvm::MCAsmBackend *mab = target->createMCAsmBackend(*sti, *mri, mcOptions); 380 mcStreamer.reset(target->createMCObjectStreamer( 381 triple, ctx, std::unique_ptr<llvm::MCAsmBackend>(mab), 382 mab->createObjectWriter(os), std::unique_ptr<llvm::MCCodeEmitter>(ce), 383 *sti, mcOptions.MCRelaxAll, mcOptions.MCIncrementalLinkerCompatible, 384 /*DWARFMustBeAtTheEnd*/ false)); 385 mcStreamer->setUseAssemblerInfoForParsing(true); 386 387 std::unique_ptr<llvm::MCAsmParser> parser( 388 createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai)); 389 std::unique_ptr<llvm::MCTargetAsmParser> tap( 390 target->createMCAsmParser(*sti, *parser, *mcii, mcOptions)); 391 392 if (!tap) { 393 emitError(loc, "assembler initialization error"); 394 return {}; 395 } 396 397 parser->setTargetParser(*tap); 398 parser->Run(false); 399 400 return std::make_unique<SmallVector<char, 0>>(std::move(result)); 401 } 402 403 std::unique_ptr<std::vector<char>> 404 SerializeToHsacoPass::createHsaco(const SmallVectorImpl<char> &isaBinary) { 405 auto loc = getOperation().getLoc(); 406 407 // Save the ISA binary to a temp file. 408 int tempIsaBinaryFd = -1; 409 SmallString<128> tempIsaBinaryFilename; 410 if (llvm::sys::fs::createTemporaryFile("kernel", "o", tempIsaBinaryFd, 411 tempIsaBinaryFilename)) { 412 emitError(loc, "temporary file for ISA binary creation error"); 413 return {}; 414 } 415 llvm::FileRemover cleanupIsaBinary(tempIsaBinaryFilename); 416 llvm::raw_fd_ostream tempIsaBinaryOs(tempIsaBinaryFd, true); 417 tempIsaBinaryOs << StringRef(isaBinary.data(), isaBinary.size()); 418 tempIsaBinaryOs.close(); 419 420 // Create a temp file for HSA code object. 421 int tempHsacoFD = -1; 422 SmallString<128> tempHsacoFilename; 423 if (llvm::sys::fs::createTemporaryFile("kernel", "hsaco", tempHsacoFD, 424 tempHsacoFilename)) { 425 emitError(loc, "temporary file for HSA code object creation error"); 426 return {}; 427 } 428 llvm::FileRemover cleanupHsaco(tempHsacoFilename); 429 430 { 431 static std::mutex mutex; 432 const std::lock_guard<std::mutex> lock(mutex); 433 // Invoke lld. Expect a true return value from lld. 434 if (!lld::elf::link({"ld.lld", "-shared", tempIsaBinaryFilename.c_str(), 435 "-o", tempHsacoFilename.c_str()}, 436 /*canEarlyExit=*/false, llvm::outs(), llvm::errs())) { 437 emitError(loc, "lld invocation error"); 438 return {}; 439 } 440 } 441 442 // Load the HSA code object. 443 auto hsacoFile = openInputFile(tempHsacoFilename); 444 if (!hsacoFile) { 445 emitError(loc, "read HSA code object from temp file error"); 446 return {}; 447 } 448 449 StringRef buffer = hsacoFile->getBuffer(); 450 return std::make_unique<std::vector<char>>(buffer.begin(), buffer.end()); 451 } 452 453 std::unique_ptr<std::vector<char>> 454 SerializeToHsacoPass::serializeISA(const std::string &isa) { 455 auto isaBinary = assembleIsa(isa); 456 if (!isaBinary) 457 return {}; 458 return createHsaco(*isaBinary); 459 } 460 461 // Register pass to serialize GPU kernel functions to a HSACO binary annotation. 462 void mlir::registerGpuSerializeToHsacoPass() { 463 PassRegistration<SerializeToHsacoPass> registerSerializeToHSACO( 464 [] { 465 // Initialize LLVM AMDGPU backend. 466 LLVMInitializeAMDGPUAsmParser(); 467 LLVMInitializeAMDGPUAsmPrinter(); 468 LLVMInitializeAMDGPUTarget(); 469 LLVMInitializeAMDGPUTargetInfo(); 470 LLVMInitializeAMDGPUTargetMC(); 471 472 return std::make_unique<SerializeToHsacoPass>("amdgcn-amd-amdhsa", "", 473 "", 2); 474 }); 475 } 476 #else // MLIR_GPU_TO_HSACO_PASS_ENABLE 477 void mlir::registerGpuSerializeToHsacoPass() {} 478 #endif // MLIR_GPU_TO_HSACO_PASS_ENABLE 479