1 //===- LowerGPUToHSACO.cpp - Convert GPU kernel to HSACO blob -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass that serializes a gpu module into HSAco blob and
10 // adds that blob as a string attribute of the module.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/GPU/Passes.h"
14 #include "mlir/IR/Location.h"
15 #include "mlir/IR/MLIRContext.h"
16 
17 #if MLIR_GPU_TO_HSACO_PASS_ENABLE
18 #include "mlir/ExecutionEngine/OptUtils.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Support/FileUtilities.h"
21 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
22 #include "mlir/Target/LLVMIR/Export.h"
23 
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/GlobalVariable.h"
26 #include "llvm/IR/Module.h"
27 #include "llvm/IRReader/IRReader.h"
28 #include "llvm/Linker/Linker.h"
29 
30 #include "llvm/MC/MCAsmBackend.h"
31 #include "llvm/MC/MCAsmInfo.h"
32 #include "llvm/MC/MCCodeEmitter.h"
33 #include "llvm/MC/MCContext.h"
34 #include "llvm/MC/MCObjectFileInfo.h"
35 #include "llvm/MC/MCObjectWriter.h"
36 #include "llvm/MC/MCParser/MCTargetAsmParser.h"
37 #include "llvm/MC/MCRegisterInfo.h"
38 #include "llvm/MC/MCStreamer.h"
39 #include "llvm/MC/MCSubtargetInfo.h"
40 #include "llvm/MC/TargetRegistry.h"
41 
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/FileUtilities.h"
44 #include "llvm/Support/Path.h"
45 #include "llvm/Support/Program.h"
46 #include "llvm/Support/SourceMgr.h"
47 #include "llvm/Support/TargetSelect.h"
48 #include "llvm/Support/WithColor.h"
49 
50 #include "llvm/Target/TargetMachine.h"
51 #include "llvm/Target/TargetOptions.h"
52 
53 #include "llvm/Transforms/IPO/Internalize.h"
54 
55 #include "lld/Common/Driver.h"
56 
#include <cstdlib>
#include <mutex>
58 
59 using namespace mlir;
60 
61 namespace {
62 class SerializeToHsacoPass
63     : public PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass> {
64 public:
65   SerializeToHsacoPass(StringRef triple, StringRef arch, StringRef features,
66                        int optLevel);
67   SerializeToHsacoPass(const SerializeToHsacoPass &other);
68   StringRef getArgument() const override { return "gpu-to-hsaco"; }
69   StringRef getDescription() const override {
70     return "Lower GPU kernel function to HSACO binary annotations";
71   }
72 
73 protected:
74   Option<int> optLevel{
75       *this, "opt-level",
76       llvm::cl::desc("Optimization level for HSACO compilation"),
77       llvm::cl::init(2)};
78 
79   Option<std::string> rocmPath{*this, "rocm-path",
80                                llvm::cl::desc("Path to ROCm install")};
81 
82   // Overload to allow linking in device libs
83   std::unique_ptr<llvm::Module>
84   translateToLLVMIR(llvm::LLVMContext &llvmContext) override;
85 
86   /// Adds LLVM optimization passes
87   LogicalResult optimizeLlvm(llvm::Module &llvmModule,
88                              llvm::TargetMachine &targetMachine) override;
89 
90 private:
91   void getDependentDialects(DialectRegistry &registry) const override;
92 
93   // Loads LLVM bitcode libraries
94   Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
95   loadLibraries(SmallVectorImpl<char> &path,
96                 SmallVectorImpl<StringRef> &libraries,
97                 llvm::LLVMContext &context);
98 
99   // Serializes ROCDL to HSACO.
100   std::unique_ptr<std::vector<char>>
101   serializeISA(const std::string &isa) override;
102 
103   std::unique_ptr<SmallVectorImpl<char>> assembleIsa(const std::string &isa);
104   std::unique_ptr<std::vector<char>>
105   createHsaco(const SmallVectorImpl<char> &isaBinary);
106 
107   std::string getRocmPath();
108 };
109 } // namespace
110 
111 SerializeToHsacoPass::SerializeToHsacoPass(const SerializeToHsacoPass &other)
112     : PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass>(other) {}
113 
114 /// Get a user-specified path to ROCm
115 // Tries, in order, the --rocm-path option, the ROCM_PATH environment variable
116 // and a compile-time default
117 std::string SerializeToHsacoPass::getRocmPath() {
118   if (rocmPath.getNumOccurrences() > 0)
119     return rocmPath.getValue();
120 
121   return __DEFAULT_ROCM_PATH__;
122 }
123 
124 // Sets the 'option' to 'value' unless it already has a value.
125 static void maybeSetOption(Pass::Option<std::string> &option,
126                            function_ref<std::string()> getValue) {
127   if (!option.hasValue())
128     option = getValue();
129 }
130 
131 SerializeToHsacoPass::SerializeToHsacoPass(StringRef triple, StringRef arch,
132                                            StringRef features, int optLevel) {
133   maybeSetOption(this->triple, [&triple] { return triple.str(); });
134   maybeSetOption(this->chip, [&arch] { return arch.str(); });
135   maybeSetOption(this->features, [&features] { return features.str(); });
136   if (this->optLevel.getNumOccurrences() == 0)
137     this->optLevel.setValue(optLevel);
138 }
139 
140 void SerializeToHsacoPass::getDependentDialects(
141     DialectRegistry &registry) const {
142   registerROCDLDialectTranslation(registry);
143   gpu::SerializeToBlobPass::getDependentDialects(registry);
144 }
145 
146 Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
147 SerializeToHsacoPass::loadLibraries(SmallVectorImpl<char> &path,
148                                     SmallVectorImpl<StringRef> &libraries,
149                                     llvm::LLVMContext &context) {
150   SmallVector<std::unique_ptr<llvm::Module>, 3> ret;
151   size_t dirLength = path.size();
152 
153   if (!llvm::sys::fs::is_directory(path)) {
154     getOperation().emitRemark() << "Bitcode path: " << path
155                                 << " does not exist or is not a directory\n";
156     return llvm::None;
157   }
158 
159   for (const StringRef file : libraries) {
160     llvm::SMDiagnostic error;
161     llvm::sys::path::append(path, file);
162     llvm::StringRef pathRef(path.data(), path.size());
163     std::unique_ptr<llvm::Module> library =
164         llvm::getLazyIRFileModule(pathRef, error, context);
165     path.truncate(dirLength);
166     if (!library) {
167       getOperation().emitError() << "Failed to load library " << file
168                                  << " from " << path << error.getMessage();
169       return llvm::None;
170     }
171     // Some ROCM builds don't strip this like they should
172     if (auto *openclVersion = library->getNamedMetadata("opencl.ocl.version"))
173       library->eraseNamedMetadata(openclVersion);
174     // Stop spamming us with clang version numbers
175     if (auto *ident = library->getNamedMetadata("llvm.ident"))
176       library->eraseNamedMetadata(ident);
177     ret.push_back(std::move(library));
178   }
179 
180   return ret;
181 }
182 
183 std::unique_ptr<llvm::Module>
184 SerializeToHsacoPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) {
185   // MLIR -> LLVM translation
186   std::unique_ptr<llvm::Module> ret =
187       gpu::SerializeToBlobPass::translateToLLVMIR(llvmContext);
188 
189   if (!ret) {
190     getOperation().emitOpError("Module lowering failed");
191     return ret;
192   }
193   // Walk the LLVM module in order to determine if we need to link in device
194   // libs
195   bool needOpenCl = false;
196   bool needOckl = false;
197   bool needOcml = false;
198   for (llvm::Function &f : ret->functions()) {
199     if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) {
200       StringRef funcName = f.getName();
201       if ("printf" == funcName)
202         needOpenCl = true;
203       if (funcName.startswith("__ockl_"))
204         needOckl = true;
205       if (funcName.startswith("__ocml_"))
206         needOcml = true;
207     }
208   }
209 
210   if (needOpenCl)
211     needOcml = needOckl = true;
212 
213   // No libraries needed (the typical case)
214   if (!(needOpenCl || needOcml || needOckl))
215     return ret;
216 
217   // Define one of the control constants the ROCm device libraries expect to be
218   // present These constants can either be defined in the module or can be
219   // imported by linking in bitcode that defines the constant. To simplify our
220   // logic, we define the constants into the module we are compiling
221   auto addControlConstant = [&module = *ret](StringRef name, uint32_t value,
222                                              uint32_t bitwidth) {
223     using llvm::GlobalVariable;
224     if (module.getNamedGlobal(name)) {
225       return;
226     }
227     llvm::IntegerType *type =
228         llvm::IntegerType::getIntNTy(module.getContext(), bitwidth);
229     auto *initializer = llvm::ConstantInt::get(type, value, /*isSigned=*/false);
230     auto *constant = new GlobalVariable(
231         module, type,
232         /*isConstant=*/true, GlobalVariable::LinkageTypes::LinkOnceODRLinkage,
233         initializer, name,
234         /*before=*/nullptr,
235         /*threadLocalMode=*/GlobalVariable::ThreadLocalMode::NotThreadLocal,
236         /*addressSpace=*/4);
237     constant->setUnnamedAddr(GlobalVariable::UnnamedAddr::Local);
238     constant->setVisibility(
239         GlobalVariable::VisibilityTypes::ProtectedVisibility);
240     constant->setAlignment(llvm::MaybeAlign(bitwidth / 8));
241   };
242 
243   if (needOcml) {
244     // TODO(kdrewnia): Enable math optimizations once we have support for
245     // `-ffast-math`-like options
246     addControlConstant("__oclc_finite_only_opt", 0, 8);
247     addControlConstant("__oclc_daz_opt", 0, 8);
248     addControlConstant("__oclc_correctly_rounded_sqrt32", 1, 8);
249     addControlConstant("__oclc_unsafe_math_opt", 0, 8);
250   }
251   if (needOcml || needOckl) {
252     addControlConstant("__oclc_wavefrontsize64", 1, 8);
253     StringRef chipSet = this->chip.getValue();
254     if (chipSet.startswith("gfx"))
255       chipSet = chipSet.substr(3);
256     uint32_t minor =
257         llvm::APInt(32, chipSet.substr(chipSet.size() - 2), 16).getZExtValue();
258     uint32_t major = llvm::APInt(32, chipSet.substr(0, chipSet.size() - 2), 10)
259                          .getZExtValue();
260     uint32_t isaNumber = minor + 1000 * major;
261     addControlConstant("__oclc_ISA_version", isaNumber, 32);
262   }
263 
264   // Determine libraries we need to link - order matters due to dependencies
265   llvm::SmallVector<StringRef, 4> libraries;
266   if (needOpenCl)
267     libraries.push_back("opencl.bc");
268   if (needOcml)
269     libraries.push_back("ocml.bc");
270   if (needOckl)
271     libraries.push_back("ockl.bc");
272 
273   Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>> mbModules;
274   std::string theRocmPath = getRocmPath();
275   llvm::SmallString<32> bitcodePath(std::move(theRocmPath));
276   llvm::sys::path::append(bitcodePath, "amdgcn", "bitcode");
277   mbModules = loadLibraries(bitcodePath, libraries, llvmContext);
278 
279   if (!mbModules) {
280     getOperation()
281             .emitWarning("Could not load required device labraries")
282             .attachNote()
283         << "This will probably cause link-time or run-time failures";
284     return ret; // We can still abort here
285   }
286 
287   llvm::Linker linker(*ret);
288   for (std::unique_ptr<llvm::Module> &libModule : mbModules.getValue()) {
289     // This bitcode linking code is substantially similar to what is used in
290     // hip-clang It imports the library functions into the module, allowing LLVM
291     // optimization passes (which must run after linking) to optimize across the
292     // libraries and the module's code. We also only import symbols if they are
293     // referenced by the module or a previous library since there will be no
294     // other source of references to those symbols in this compilation and since
295     // we don't want to bloat the resulting code object.
296     bool err = linker.linkInModule(
297         std::move(libModule), llvm::Linker::Flags::LinkOnlyNeeded,
298         [](llvm::Module &m, const StringSet<> &gvs) {
299           llvm::internalizeModule(m, [&gvs](const llvm::GlobalValue &gv) {
300             return !gv.hasName() || (gvs.count(gv.getName()) == 0);
301           });
302         });
303     // True is linker failure
304     if (err) {
305       getOperation().emitError(
306           "Unrecoverable failure during device library linking.");
307       // We have no guaranties about the state of `ret`, so bail
308       return nullptr;
309     }
310   }
311 
312   // Set amdgpu_hostcall if host calls have been linked, as needed by newer LLVM
313   // FIXME: Is there a way to set this during printf() lowering that makes sense
314   if (ret->getFunction("__ockl_hostcall_internal"))
315     if (!ret->getModuleFlag("amdgpu_hostcall"))
316       ret->addModuleFlag(llvm::Module::Override, "amdgpu_hostcall", 1);
317   return ret;
318 }
319 
320 LogicalResult
321 SerializeToHsacoPass::optimizeLlvm(llvm::Module &llvmModule,
322                                    llvm::TargetMachine &targetMachine) {
323   int optLevel = this->optLevel.getValue();
324   if (optLevel < 0 || optLevel > 3)
325     return getOperation().emitError()
326            << "Invalid HSA optimization level" << optLevel << "\n";
327 
328   targetMachine.setOptLevel(static_cast<llvm::CodeGenOpt::Level>(optLevel));
329 
330   auto transformer =
331       makeOptimizingTransformer(optLevel, /*sizeLevel=*/0, &targetMachine);
332   auto error = transformer(&llvmModule);
333   if (error) {
334     InFlightDiagnostic mlirError = getOperation()->emitError();
335     llvm::handleAllErrors(
336         std::move(error), [&mlirError](const llvm::ErrorInfoBase &ei) {
337           mlirError << "Could not optimize LLVM IR: " << ei.message() << "\n";
338         });
339     return mlirError;
340   }
341   return success();
342 }
343 
344 std::unique_ptr<SmallVectorImpl<char>>
345 SerializeToHsacoPass::assembleIsa(const std::string &isa) {
346   auto loc = getOperation().getLoc();
347 
348   SmallVector<char, 0> result;
349   llvm::raw_svector_ostream os(result);
350 
351   llvm::Triple triple(llvm::Triple::normalize(this->triple));
352   std::string error;
353   const llvm::Target *target =
354       llvm::TargetRegistry::lookupTarget(triple.normalize(), error);
355   if (!target) {
356     emitError(loc, Twine("failed to lookup target: ") + error);
357     return {};
358   }
359 
360   llvm::SourceMgr srcMgr;
361   srcMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(isa),
362                             SMLoc());
363 
364   const llvm::MCTargetOptions mcOptions;
365   std::unique_ptr<llvm::MCRegisterInfo> mri(
366       target->createMCRegInfo(this->triple));
367   std::unique_ptr<llvm::MCAsmInfo> mai(
368       target->createMCAsmInfo(*mri, this->triple, mcOptions));
369   mai->setRelaxELFRelocations(true);
370   std::unique_ptr<llvm::MCSubtargetInfo> sti(
371       target->createMCSubtargetInfo(this->triple, this->chip, this->features));
372 
373   llvm::MCContext ctx(triple, mai.get(), mri.get(), sti.get(), &srcMgr,
374                       &mcOptions);
375   std::unique_ptr<llvm::MCObjectFileInfo> mofi(target->createMCObjectFileInfo(
376       ctx, /*PIC=*/false, /*LargeCodeModel=*/false));
377   ctx.setObjectFileInfo(mofi.get());
378 
379   SmallString<128> cwd;
380   if (!llvm::sys::fs::current_path(cwd))
381     ctx.setCompilationDir(cwd);
382 
383   std::unique_ptr<llvm::MCStreamer> mcStreamer;
384   std::unique_ptr<llvm::MCInstrInfo> mcii(target->createMCInstrInfo());
385 
386   llvm::MCCodeEmitter *ce = target->createMCCodeEmitter(*mcii, *mri, ctx);
387   llvm::MCAsmBackend *mab = target->createMCAsmBackend(*sti, *mri, mcOptions);
388   mcStreamer.reset(target->createMCObjectStreamer(
389       triple, ctx, std::unique_ptr<llvm::MCAsmBackend>(mab),
390       mab->createObjectWriter(os), std::unique_ptr<llvm::MCCodeEmitter>(ce),
391       *sti, mcOptions.MCRelaxAll, mcOptions.MCIncrementalLinkerCompatible,
392       /*DWARFMustBeAtTheEnd*/ false));
393   mcStreamer->setUseAssemblerInfoForParsing(true);
394 
395   std::unique_ptr<llvm::MCAsmParser> parser(
396       createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai));
397   std::unique_ptr<llvm::MCTargetAsmParser> tap(
398       target->createMCAsmParser(*sti, *parser, *mcii, mcOptions));
399 
400   if (!tap) {
401     emitError(loc, "assembler initialization error");
402     return {};
403   }
404 
405   parser->setTargetParser(*tap);
406   parser->Run(false);
407 
408   return std::make_unique<SmallVector<char, 0>>(std::move(result));
409 }
410 
411 std::unique_ptr<std::vector<char>>
412 SerializeToHsacoPass::createHsaco(const SmallVectorImpl<char> &isaBinary) {
413   auto loc = getOperation().getLoc();
414 
415   // Save the ISA binary to a temp file.
416   int tempIsaBinaryFd = -1;
417   SmallString<128> tempIsaBinaryFilename;
418   if (llvm::sys::fs::createTemporaryFile("kernel", "o", tempIsaBinaryFd,
419                                          tempIsaBinaryFilename)) {
420     emitError(loc, "temporary file for ISA binary creation error");
421     return {};
422   }
423   llvm::FileRemover cleanupIsaBinary(tempIsaBinaryFilename);
424   llvm::raw_fd_ostream tempIsaBinaryOs(tempIsaBinaryFd, true);
425   tempIsaBinaryOs << StringRef(isaBinary.data(), isaBinary.size());
426   tempIsaBinaryOs.close();
427 
428   // Create a temp file for HSA code object.
429   int tempHsacoFD = -1;
430   SmallString<128> tempHsacoFilename;
431   if (llvm::sys::fs::createTemporaryFile("kernel", "hsaco", tempHsacoFD,
432                                          tempHsacoFilename)) {
433     emitError(loc, "temporary file for HSA code object creation error");
434     return {};
435   }
436   llvm::FileRemover cleanupHsaco(tempHsacoFilename);
437 
438   {
439     static std::mutex mutex;
440     const std::lock_guard<std::mutex> lock(mutex);
441     // Invoke lld. Expect a true return value from lld.
442     bool r = lld::elf::link({"ld.lld", "-shared", tempIsaBinaryFilename.c_str(),
443                              "-o", tempHsacoFilename.c_str()},
444                             llvm::outs(), llvm::errs(), /*exitEarly=*/false,
445                             /*disableOutput=*/false);
446     // Allow for calling the driver again in the same process.
447     lld::cleanup();
448     if (!r) {
449       emitError(loc, "lld invocation error");
450       return {};
451     }
452   }
453 
454   // Load the HSA code object.
455   auto hsacoFile = openInputFile(tempHsacoFilename);
456   if (!hsacoFile) {
457     emitError(loc, "read HSA code object from temp file error");
458     return {};
459   }
460 
461   StringRef buffer = hsacoFile->getBuffer();
462   return std::make_unique<std::vector<char>>(buffer.begin(), buffer.end());
463 }
464 
465 std::unique_ptr<std::vector<char>>
466 SerializeToHsacoPass::serializeISA(const std::string &isa) {
467   auto isaBinary = assembleIsa(isa);
468   if (!isaBinary)
469     return {};
470   return createHsaco(*isaBinary);
471 }
472 
473 // Register pass to serialize GPU kernel functions to a HSACO binary annotation.
474 void mlir::registerGpuSerializeToHsacoPass() {
475   PassRegistration<SerializeToHsacoPass> registerSerializeToHSACO(
476       [] {
477         // Initialize LLVM AMDGPU backend.
478         LLVMInitializeAMDGPUAsmParser();
479         LLVMInitializeAMDGPUAsmPrinter();
480         LLVMInitializeAMDGPUTarget();
481         LLVMInitializeAMDGPUTargetInfo();
482         LLVMInitializeAMDGPUTargetMC();
483 
484         return std::make_unique<SerializeToHsacoPass>("amdgcn-amd-amdhsa", "",
485                                                       "", 2);
486       });
487 }
488 
489 /// Create an instance of the GPU kernel function to HSAco binary serialization
490 /// pass.
491 std::unique_ptr<Pass> mlir::createGpuSerializeToHsacoPass(StringRef triple,
492                                                           StringRef arch,
493                                                           StringRef features,
494                                                           int optLevel) {
495   return std::make_unique<SerializeToHsacoPass>(triple, arch, features,
496                                                 optLevel);
497 }
498 
499 #else  // MLIR_GPU_TO_HSACO_PASS_ENABLE
500 void mlir::registerGpuSerializeToHsacoPass() {}
501 #endif // MLIR_GPU_TO_HSACO_PASS_ENABLE
502