1 //===- LowerGPUToHSACO.cpp - Convert GPU kernel to HSACO blob -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass that serializes a GPU module into an HSACO blob
// and adds that blob as a string attribute of the module.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/GPU/Passes.h"
14 #include "mlir/IR/Location.h"
15 #include "mlir/IR/MLIRContext.h"
16 
17 #if MLIR_GPU_TO_HSACO_PASS_ENABLE
18 #include "mlir/ExecutionEngine/OptUtils.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Support/FileUtilities.h"
21 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
22 #include "mlir/Target/LLVMIR/Export.h"
23 
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/GlobalVariable.h"
26 #include "llvm/IR/Module.h"
27 #include "llvm/IRReader/IRReader.h"
28 #include "llvm/Linker/Linker.h"
29 
30 #include "llvm/MC/MCAsmBackend.h"
31 #include "llvm/MC/MCAsmInfo.h"
32 #include "llvm/MC/MCCodeEmitter.h"
33 #include "llvm/MC/MCContext.h"
34 #include "llvm/MC/MCObjectFileInfo.h"
35 #include "llvm/MC/MCObjectWriter.h"
36 #include "llvm/MC/MCParser/MCTargetAsmParser.h"
37 #include "llvm/MC/MCStreamer.h"
38 #include "llvm/MC/MCSubtargetInfo.h"
39 #include "llvm/MC/TargetRegistry.h"
40 
41 #include "llvm/Support/CommandLine.h"
42 #include "llvm/Support/FileUtilities.h"
43 #include "llvm/Support/Program.h"
44 #include "llvm/Support/SourceMgr.h"
45 #include "llvm/Support/TargetSelect.h"
46 #include "llvm/Support/WithColor.h"
47 
48 #include "llvm/Target/TargetMachine.h"
49 #include "llvm/Target/TargetOptions.h"
50 
51 #include "llvm/Transforms/IPO/Internalize.h"
52 
53 #include "lld/Common/Driver.h"
54 
#include <cstdlib>
#include <mutex>
56 
57 using namespace mlir;
58 
59 namespace {
60 class SerializeToHsacoPass
61     : public PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass> {
62 public:
63   SerializeToHsacoPass(StringRef triple, StringRef arch, StringRef features,
64                        int optLevel);
65   SerializeToHsacoPass(const SerializeToHsacoPass &other);
66   StringRef getArgument() const override { return "gpu-to-hsaco"; }
67   StringRef getDescription() const override {
68     return "Lower GPU kernel function to HSACO binary annotations";
69   }
70 
71 protected:
72   Option<int> optLevel{
73       *this, "opt-level",
74       llvm::cl::desc("Optimization level for HSACO compilation"),
75       llvm::cl::init(2)};
76 
77   Option<std::string> rocmPath{*this, "rocm-path",
78                                llvm::cl::desc("Path to ROCm install")};
79 
80   // Overload to allow linking in device libs
81   std::unique_ptr<llvm::Module>
82   translateToLLVMIR(llvm::LLVMContext &llvmContext) override;
83 
84   /// Adds LLVM optimization passes
85   LogicalResult optimizeLlvm(llvm::Module &llvmModule,
86                              llvm::TargetMachine &targetMachine) override;
87 
88 private:
89   void getDependentDialects(DialectRegistry &registry) const override;
90 
91   // Loads LLVM bitcode libraries
92   Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
93   loadLibraries(SmallVectorImpl<char> &path,
94                 SmallVectorImpl<StringRef> &libraries,
95                 llvm::LLVMContext &context);
96 
97   // Serializes ROCDL to HSACO.
98   std::unique_ptr<std::vector<char>>
99   serializeISA(const std::string &isa) override;
100 
101   std::unique_ptr<SmallVectorImpl<char>> assembleIsa(const std::string &isa);
102   std::unique_ptr<std::vector<char>>
103   createHsaco(const SmallVectorImpl<char> &isaBinary);
104 
105   std::string getRocmPath();
106 };
107 } // end namespace
108 
109 SerializeToHsacoPass::SerializeToHsacoPass(const SerializeToHsacoPass &other)
110     : PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass>(other) {}
111 
112 /// Get a user-specified path to ROCm
113 // Tries, in order, the --rocm-path option, the ROCM_PATH environment variable
114 // and a compile-time default
115 std::string SerializeToHsacoPass::getRocmPath() {
116   if (rocmPath.getNumOccurrences() > 0)
117     return rocmPath.getValue();
118 
119   return __DEFAULT_ROCM_PATH__;
120 }
121 
122 // Sets the 'option' to 'value' unless it already has a value.
123 static void maybeSetOption(Pass::Option<std::string> &option,
124                            function_ref<std::string()> getValue) {
125   if (!option.hasValue())
126     option = getValue();
127 }
128 
129 SerializeToHsacoPass::SerializeToHsacoPass(StringRef triple, StringRef arch,
130                                            StringRef features, int optLevel) {
131   maybeSetOption(this->triple, [&triple] { return triple.str(); });
132   maybeSetOption(this->chip, [&arch] { return arch.str(); });
133   maybeSetOption(this->features, [&features] { return features.str(); });
134   if (this->optLevel.getNumOccurrences() == 0)
135     this->optLevel.setValue(optLevel);
136 }
137 
138 void SerializeToHsacoPass::getDependentDialects(
139     DialectRegistry &registry) const {
140   registerROCDLDialectTranslation(registry);
141   gpu::SerializeToBlobPass::getDependentDialects(registry);
142 }
143 
144 Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
145 SerializeToHsacoPass::loadLibraries(SmallVectorImpl<char> &path,
146                                     SmallVectorImpl<StringRef> &libraries,
147                                     llvm::LLVMContext &context) {
148   SmallVector<std::unique_ptr<llvm::Module>, 3> ret;
149   size_t dirLength = path.size();
150 
151   if (!llvm::sys::fs::is_directory(path)) {
152     getOperation().emitRemark() << "Bitcode path: " << path
153                                 << " does not exist or is not a directory\n";
154     return llvm::None;
155   }
156 
157   for (const StringRef file : libraries) {
158     llvm::SMDiagnostic error;
159     llvm::sys::path::append(path, file);
160     llvm::StringRef pathRef(path.data(), path.size());
161     std::unique_ptr<llvm::Module> library =
162         llvm::getLazyIRFileModule(pathRef, error, context);
163     path.set_size(dirLength);
164     if (!library) {
165       getOperation().emitError() << "Failed to load library " << file
166                                  << " from " << path << error.getMessage();
167       return llvm::None;
168     }
169     // Some ROCM builds don't strip this like they should
170     if (auto *openclVersion = library->getNamedMetadata("opencl.ocl.version"))
171       library->eraseNamedMetadata(openclVersion);
172     // Stop spamming us with clang version numbers
173     if (auto *ident = library->getNamedMetadata("llvm.ident"))
174       library->eraseNamedMetadata(ident);
175     ret.push_back(std::move(library));
176   }
177 
178   return ret;
179 }
180 
181 std::unique_ptr<llvm::Module>
182 SerializeToHsacoPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) {
183   // MLIR -> LLVM translation
184   std::unique_ptr<llvm::Module> ret =
185       gpu::SerializeToBlobPass::translateToLLVMIR(llvmContext);
186 
187   if (!ret) {
188     getOperation().emitOpError("Module lowering failed");
189     return ret;
190   }
191   // Walk the LLVM module in order to determine if we need to link in device
192   // libs
193   bool needOpenCl = false;
194   bool needOckl = false;
195   bool needOcml = false;
196   for (llvm::Function &f : ret->functions()) {
197     if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) {
198       StringRef funcName = f.getName();
199       if ("printf" == funcName)
200         needOpenCl = true;
201       if (funcName.startswith("__ockl_"))
202         needOckl = true;
203       if (funcName.startswith("__ocml_"))
204         needOcml = true;
205     }
206   }
207 
208   if (needOpenCl)
209     needOcml = needOckl = true;
210 
211   // No libraries needed (the typical case)
212   if (!(needOpenCl || needOcml || needOckl))
213     return ret;
214 
215   // Define one of the control constants the ROCm device libraries expect to be
216   // present These constants can either be defined in the module or can be
217   // imported by linking in bitcode that defines the constant. To simplify our
218   // logic, we define the constants into the module we are compiling
219   auto addControlConstant = [&module = *ret](StringRef name, uint32_t value,
220                                              uint32_t bitwidth) {
221     using llvm::GlobalVariable;
222     if (module.getNamedGlobal(name)) {
223       return;
224     }
225     llvm::IntegerType *type =
226         llvm::IntegerType::getIntNTy(module.getContext(), bitwidth);
227     auto *initializer = llvm::ConstantInt::get(type, value, /*isSigned=*/false);
228     auto *constant = new GlobalVariable(
229         module, type,
230         /*isConstant=*/true, GlobalVariable::LinkageTypes::LinkOnceODRLinkage,
231         initializer, name,
232         /*before=*/nullptr,
233         /*threadLocalMode=*/GlobalVariable::ThreadLocalMode::NotThreadLocal,
234         /*addressSpace=*/4);
235     constant->setUnnamedAddr(GlobalVariable::UnnamedAddr::Local);
236     constant->setVisibility(
237         GlobalVariable::VisibilityTypes::ProtectedVisibility);
238     constant->setAlignment(llvm::MaybeAlign(bitwidth / 8));
239   };
240 
241   if (needOcml) {
242     // TODO(kdrewnia): Enable math optimizations once we have support for
243     // `-ffast-math`-like options
244     addControlConstant("__oclc_finite_only_opt", 0, 8);
245     addControlConstant("__oclc_daz_opt", 0, 8);
246     addControlConstant("__oclc_correctly_rounded_sqrt32", 1, 8);
247     addControlConstant("__oclc_unsafe_math_opt", 0, 8);
248   }
249   if (needOcml || needOckl) {
250     addControlConstant("__oclc_wavefrontsize64", 1, 8);
251     StringRef chipSet = this->chip.getValue();
252     if (chipSet.startswith("gfx"))
253       chipSet = chipSet.substr(3);
254     uint32_t minor =
255         llvm::APInt(32, chipSet.substr(chipSet.size() - 2), 16).getZExtValue();
256     uint32_t major = llvm::APInt(32, chipSet.substr(0, chipSet.size() - 2), 10)
257                          .getZExtValue();
258     uint32_t isaNumber = minor + 1000 * major;
259     addControlConstant("__oclc_ISA_version", isaNumber, 32);
260   }
261 
262   // Determine libraries we need to link - order matters due to dependencies
263   llvm::SmallVector<StringRef, 4> libraries;
264   if (needOpenCl)
265     libraries.push_back("opencl.bc");
266   if (needOcml)
267     libraries.push_back("ocml.bc");
268   if (needOckl)
269     libraries.push_back("ockl.bc");
270 
271   Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>> mbModules;
272   std::string theRocmPath = getRocmPath();
273   llvm::SmallString<32> bitcodePath(std::move(theRocmPath));
274   llvm::sys::path::append(bitcodePath, "amdgcn", "bitcode");
275   mbModules = loadLibraries(bitcodePath, libraries, llvmContext);
276 
277   if (!mbModules) {
278     getOperation()
279             .emitWarning("Could not load required device labraries")
280             .attachNote()
281         << "This will probably cause link-time or run-time failures";
282     return ret; // We can still abort here
283   }
284 
285   llvm::Linker linker(*ret);
286   for (std::unique_ptr<llvm::Module> &libModule : mbModules.getValue()) {
287     // This bitcode linking code is substantially similar to what is used in
288     // hip-clang It imports the library functions into the module, allowing LLVM
289     // optimization passes (which must run after linking) to optimize across the
290     // libraries and the module's code. We also only import symbols if they are
291     // referenced by the module or a previous library since there will be no
292     // other source of references to those symbols in this compilation and since
293     // we don't want to bloat the resulting code object.
294     bool err = linker.linkInModule(
295         std::move(libModule), llvm::Linker::Flags::LinkOnlyNeeded,
296         [](llvm::Module &m, const StringSet<> &gvs) {
297           llvm::internalizeModule(m, [&gvs](const llvm::GlobalValue &gv) {
298             return !gv.hasName() || (gvs.count(gv.getName()) == 0);
299           });
300         });
301     // True is linker failure
302     if (err) {
303       getOperation().emitError(
304           "Unrecoverable failure during device library linking.");
305       // We have no guaranties about the state of `ret`, so bail
306       return nullptr;
307     }
308   }
309   return ret;
310 }
311 
312 LogicalResult
313 SerializeToHsacoPass::optimizeLlvm(llvm::Module &llvmModule,
314                                    llvm::TargetMachine &targetMachine) {
315   int optLevel = this->optLevel.getValue();
316   if (optLevel < 0 || optLevel > 3)
317     return getOperation().emitError()
318            << "Invalid HSA optimization level" << optLevel << "\n";
319 
320   targetMachine.setOptLevel(static_cast<llvm::CodeGenOpt::Level>(optLevel));
321 
322   auto transformer =
323       makeOptimizingTransformer(optLevel, /*sizeLevel=*/0, &targetMachine);
324   auto error = transformer(&llvmModule);
325   if (error) {
326     InFlightDiagnostic mlirError = getOperation()->emitError();
327     llvm::handleAllErrors(
328         std::move(error), [&mlirError](const llvm::ErrorInfoBase &ei) {
329           mlirError << "Could not optimize LLVM IR: " << ei.message() << "\n";
330         });
331     return mlirError;
332   }
333   return success();
334 }
335 
336 std::unique_ptr<SmallVectorImpl<char>>
337 SerializeToHsacoPass::assembleIsa(const std::string &isa) {
338   auto loc = getOperation().getLoc();
339 
340   SmallVector<char, 0> result;
341   llvm::raw_svector_ostream os(result);
342 
343   llvm::Triple triple(llvm::Triple::normalize(this->triple));
344   std::string error;
345   const llvm::Target *target =
346       llvm::TargetRegistry::lookupTarget(triple.normalize(), error);
347   if (!target) {
348     emitError(loc, Twine("failed to lookup target: ") + error);
349     return {};
350   }
351 
352   llvm::SourceMgr srcMgr;
353   srcMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(isa),
354                             llvm::SMLoc());
355 
356   const llvm::MCTargetOptions mcOptions;
357   std::unique_ptr<llvm::MCRegisterInfo> mri(
358       target->createMCRegInfo(this->triple));
359   std::unique_ptr<llvm::MCAsmInfo> mai(
360       target->createMCAsmInfo(*mri, this->triple, mcOptions));
361   mai->setRelaxELFRelocations(true);
362   std::unique_ptr<llvm::MCSubtargetInfo> sti(
363       target->createMCSubtargetInfo(this->triple, this->chip, this->features));
364 
365   llvm::MCContext ctx(triple, mai.get(), mri.get(), sti.get(), &srcMgr,
366                       &mcOptions);
367   std::unique_ptr<llvm::MCObjectFileInfo> mofi(target->createMCObjectFileInfo(
368       ctx, /*PIC=*/false, /*LargeCodeModel=*/false));
369   ctx.setObjectFileInfo(mofi.get());
370 
371   SmallString<128> cwd;
372   if (!llvm::sys::fs::current_path(cwd))
373     ctx.setCompilationDir(cwd);
374 
375   std::unique_ptr<llvm::MCStreamer> mcStreamer;
376   std::unique_ptr<llvm::MCInstrInfo> mcii(target->createMCInstrInfo());
377 
378   llvm::MCCodeEmitter *ce = target->createMCCodeEmitter(*mcii, *mri, ctx);
379   llvm::MCAsmBackend *mab = target->createMCAsmBackend(*sti, *mri, mcOptions);
380   mcStreamer.reset(target->createMCObjectStreamer(
381       triple, ctx, std::unique_ptr<llvm::MCAsmBackend>(mab),
382       mab->createObjectWriter(os), std::unique_ptr<llvm::MCCodeEmitter>(ce),
383       *sti, mcOptions.MCRelaxAll, mcOptions.MCIncrementalLinkerCompatible,
384       /*DWARFMustBeAtTheEnd*/ false));
385   mcStreamer->setUseAssemblerInfoForParsing(true);
386 
387   std::unique_ptr<llvm::MCAsmParser> parser(
388       createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai));
389   std::unique_ptr<llvm::MCTargetAsmParser> tap(
390       target->createMCAsmParser(*sti, *parser, *mcii, mcOptions));
391 
392   if (!tap) {
393     emitError(loc, "assembler initialization error");
394     return {};
395   }
396 
397   parser->setTargetParser(*tap);
398   parser->Run(false);
399 
400   return std::make_unique<SmallVector<char, 0>>(std::move(result));
401 }
402 
403 std::unique_ptr<std::vector<char>>
404 SerializeToHsacoPass::createHsaco(const SmallVectorImpl<char> &isaBinary) {
405   auto loc = getOperation().getLoc();
406 
407   // Save the ISA binary to a temp file.
408   int tempIsaBinaryFd = -1;
409   SmallString<128> tempIsaBinaryFilename;
410   if (llvm::sys::fs::createTemporaryFile("kernel", "o", tempIsaBinaryFd,
411                                          tempIsaBinaryFilename)) {
412     emitError(loc, "temporary file for ISA binary creation error");
413     return {};
414   }
415   llvm::FileRemover cleanupIsaBinary(tempIsaBinaryFilename);
416   llvm::raw_fd_ostream tempIsaBinaryOs(tempIsaBinaryFd, true);
417   tempIsaBinaryOs << StringRef(isaBinary.data(), isaBinary.size());
418   tempIsaBinaryOs.close();
419 
420   // Create a temp file for HSA code object.
421   int tempHsacoFD = -1;
422   SmallString<128> tempHsacoFilename;
423   if (llvm::sys::fs::createTemporaryFile("kernel", "hsaco", tempHsacoFD,
424                                          tempHsacoFilename)) {
425     emitError(loc, "temporary file for HSA code object creation error");
426     return {};
427   }
428   llvm::FileRemover cleanupHsaco(tempHsacoFilename);
429 
430   {
431     static std::mutex mutex;
432     const std::lock_guard<std::mutex> lock(mutex);
433     // Invoke lld. Expect a true return value from lld.
434     if (!lld::elf::link({"ld.lld", "-shared", tempIsaBinaryFilename.c_str(),
435                          "-o", tempHsacoFilename.c_str()},
436                         /*canEarlyExit=*/false, llvm::outs(), llvm::errs())) {
437       emitError(loc, "lld invocation error");
438       return {};
439     }
440   }
441 
442   // Load the HSA code object.
443   auto hsacoFile = openInputFile(tempHsacoFilename);
444   if (!hsacoFile) {
445     emitError(loc, "read HSA code object from temp file error");
446     return {};
447   }
448 
449   StringRef buffer = hsacoFile->getBuffer();
450   return std::make_unique<std::vector<char>>(buffer.begin(), buffer.end());
451 }
452 
453 std::unique_ptr<std::vector<char>>
454 SerializeToHsacoPass::serializeISA(const std::string &isa) {
455   auto isaBinary = assembleIsa(isa);
456   if (!isaBinary)
457     return {};
458   return createHsaco(*isaBinary);
459 }
460 
461 // Register pass to serialize GPU kernel functions to a HSACO binary annotation.
462 void mlir::registerGpuSerializeToHsacoPass() {
463   PassRegistration<SerializeToHsacoPass> registerSerializeToHSACO(
464       [] {
465         // Initialize LLVM AMDGPU backend.
466         LLVMInitializeAMDGPUAsmParser();
467         LLVMInitializeAMDGPUAsmPrinter();
468         LLVMInitializeAMDGPUTarget();
469         LLVMInitializeAMDGPUTargetInfo();
470         LLVMInitializeAMDGPUTargetMC();
471 
472         return std::make_unique<SerializeToHsacoPass>("amdgcn-amd-amdhsa", "",
473                                                       "", 2);
474       });
475 }
476 #else  // MLIR_GPU_TO_HSACO_PASS_ENABLE
477 void mlir::registerGpuSerializeToHsacoPass() {}
478 #endif // MLIR_GPU_TO_HSACO_PASS_ENABLE
479