1 //===- LowerGPUToHSACO.cpp - Convert GPU kernel to HSACO blob -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass that serializes a GPU module into an HSA code
// object (HSACO) blob and adds that blob as a string attribute of the module.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/GPU/Passes.h"
14 #include "mlir/IR/Location.h"
15 #include "mlir/IR/MLIRContext.h"
16 
17 #if MLIR_GPU_TO_HSACO_PASS_ENABLE
18 #include "mlir/ExecutionEngine/OptUtils.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Support/FileUtilities.h"
21 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
22 #include "mlir/Target/LLVMIR/Export.h"
23 
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/GlobalVariable.h"
26 #include "llvm/IR/Module.h"
27 #include "llvm/IRReader/IRReader.h"
28 #include "llvm/Linker/Linker.h"
29 
30 #include "llvm/MC/MCAsmBackend.h"
31 #include "llvm/MC/MCAsmInfo.h"
32 #include "llvm/MC/MCCodeEmitter.h"
33 #include "llvm/MC/MCContext.h"
34 #include "llvm/MC/MCInstrInfo.h"
35 #include "llvm/MC/MCObjectFileInfo.h"
36 #include "llvm/MC/MCObjectWriter.h"
37 #include "llvm/MC/MCParser/MCTargetAsmParser.h"
38 #include "llvm/MC/MCRegisterInfo.h"
39 #include "llvm/MC/MCStreamer.h"
40 #include "llvm/MC/MCSubtargetInfo.h"
41 #include "llvm/MC/TargetRegistry.h"
42 
43 #include "llvm/Support/CommandLine.h"
44 #include "llvm/Support/FileUtilities.h"
45 #include "llvm/Support/Path.h"
46 #include "llvm/Support/Program.h"
47 #include "llvm/Support/SourceMgr.h"
48 #include "llvm/Support/TargetSelect.h"
49 #include "llvm/Support/WithColor.h"
50 
51 #include "llvm/Target/TargetMachine.h"
52 #include "llvm/Target/TargetOptions.h"
53 
54 #include "llvm/Transforms/IPO/Internalize.h"
55 
56 #include <mutex>
57 
58 using namespace mlir;
59 
namespace {
/// Pass that serializes a GPU module to an HSA code object (HSACO) for AMD
/// GPUs. On top of the generic gpu::SerializeToBlobPass flow, it links ROCm
/// device bitcode libraries into the translated LLVM module when the module
/// calls into them. It runs LLVM optimizations at a configurable level,
/// assembles the ISA in-process and links the object with ROCm's `ld.lld`.
class SerializeToHsacoPass
    : public PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SerializeToHsacoPass)

  SerializeToHsacoPass(StringRef triple, StringRef arch, StringRef features,
                       int optLevel);
  SerializeToHsacoPass(const SerializeToHsacoPass &other);
  StringRef getArgument() const override { return "gpu-to-hsaco"; }
  StringRef getDescription() const override {
    return "Lower GPU kernel function to HSACO binary annotations";
  }

protected:
  /// LLVM optimization level used for HSACO compilation; defaults to 2.
  /// Values outside 0-3 are rejected when the optimizer runs.
  Option<int> optLevel{
      *this, "opt-level",
      llvm::cl::desc("Optimization level for HSACO compilation"),
      llvm::cl::init(2)};

  /// Root of the ROCm installation. When not given, a compile-time default is
  /// used instead.
  Option<std::string> rocmPath{*this, "rocm-path",
                               llvm::cl::desc("Path to ROCm install")};

  // Override of the base translation that also links in the ROCm device
  // libraries when the translated module needs them.
  std::unique_ptr<llvm::Module>
  translateToLLVMIR(llvm::LLVMContext &llvmContext) override;

  /// Sets the target machine's codegen level and runs the LLVM optimization
  /// pipeline on `llvmModule` at `optLevel`.
  LogicalResult optimizeLlvm(llvm::Module &llvmModule,
                             llvm::TargetMachine &targetMachine) override;

private:
  void getDependentDialects(DialectRegistry &registry) const override;

  // Loads the LLVM bitcode files named in `libraries` from directory `path`.
  Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
  loadLibraries(SmallVectorImpl<char> &path,
                SmallVectorImpl<StringRef> &libraries,
                llvm::LLVMContext &context);

  // Serializes ISA text to HSACO by assembling it and linking the object.
  std::unique_ptr<std::vector<char>>
  serializeISA(const std::string &isa) override;

  // Assembles textual ISA into an in-memory object file.
  std::unique_ptr<SmallVectorImpl<char>> assembleIsa(const std::string &isa);
  // Links an assembled object file into an HSA code object via `ld.lld`.
  std::unique_ptr<std::vector<char>>
  createHsaco(const SmallVectorImpl<char> &isaBinary);

  // Returns the ROCm install root (--rocm-path or a compile-time default).
  std::string getRocmPath();
};
} // namespace
111 
112 SerializeToHsacoPass::SerializeToHsacoPass(const SerializeToHsacoPass &other)
113     : PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass>(other) {}
114 
115 /// Get a user-specified path to ROCm
116 // Tries, in order, the --rocm-path option, the ROCM_PATH environment variable
117 // and a compile-time default
118 std::string SerializeToHsacoPass::getRocmPath() {
119   if (rocmPath.getNumOccurrences() > 0)
120     return rocmPath.getValue();
121 
122   return __DEFAULT_ROCM_PATH__;
123 }
124 
125 // Sets the 'option' to 'value' unless it already has a value.
126 static void maybeSetOption(Pass::Option<std::string> &option,
127                            function_ref<std::string()> getValue) {
128   if (!option.hasValue())
129     option = getValue();
130 }
131 
132 SerializeToHsacoPass::SerializeToHsacoPass(StringRef triple, StringRef arch,
133                                            StringRef features, int optLevel) {
134   maybeSetOption(this->triple, [&triple] { return triple.str(); });
135   maybeSetOption(this->chip, [&arch] { return arch.str(); });
136   maybeSetOption(this->features, [&features] { return features.str(); });
137   if (this->optLevel.getNumOccurrences() == 0)
138     this->optLevel.setValue(optLevel);
139 }
140 
141 void SerializeToHsacoPass::getDependentDialects(
142     DialectRegistry &registry) const {
143   registerROCDLDialectTranslation(registry);
144   gpu::SerializeToBlobPass::getDependentDialects(registry);
145 }
146 
147 Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
148 SerializeToHsacoPass::loadLibraries(SmallVectorImpl<char> &path,
149                                     SmallVectorImpl<StringRef> &libraries,
150                                     llvm::LLVMContext &context) {
151   SmallVector<std::unique_ptr<llvm::Module>, 3> ret;
152   size_t dirLength = path.size();
153 
154   if (!llvm::sys::fs::is_directory(path)) {
155     getOperation().emitRemark() << "Bitcode path: " << path
156                                 << " does not exist or is not a directory\n";
157     return llvm::None;
158   }
159 
160   for (const StringRef file : libraries) {
161     llvm::SMDiagnostic error;
162     llvm::sys::path::append(path, file);
163     llvm::StringRef pathRef(path.data(), path.size());
164     std::unique_ptr<llvm::Module> library =
165         llvm::getLazyIRFileModule(pathRef, error, context);
166     path.truncate(dirLength);
167     if (!library) {
168       getOperation().emitError() << "Failed to load library " << file
169                                  << " from " << path << error.getMessage();
170       return llvm::None;
171     }
172     // Some ROCM builds don't strip this like they should
173     if (auto *openclVersion = library->getNamedMetadata("opencl.ocl.version"))
174       library->eraseNamedMetadata(openclVersion);
175     // Stop spamming us with clang version numbers
176     if (auto *ident = library->getNamedMetadata("llvm.ident"))
177       library->eraseNamedMetadata(ident);
178     ret.push_back(std::move(library));
179   }
180 
181   return ret;
182 }
183 
184 std::unique_ptr<llvm::Module>
185 SerializeToHsacoPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) {
186   // MLIR -> LLVM translation
187   std::unique_ptr<llvm::Module> ret =
188       gpu::SerializeToBlobPass::translateToLLVMIR(llvmContext);
189 
190   if (!ret) {
191     getOperation().emitOpError("Module lowering failed");
192     return ret;
193   }
194   // Walk the LLVM module in order to determine if we need to link in device
195   // libs
196   bool needOpenCl = false;
197   bool needOckl = false;
198   bool needOcml = false;
199   for (llvm::Function &f : ret->functions()) {
200     if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) {
201       StringRef funcName = f.getName();
202       if ("printf" == funcName)
203         needOpenCl = true;
204       if (funcName.startswith("__ockl_"))
205         needOckl = true;
206       if (funcName.startswith("__ocml_"))
207         needOcml = true;
208     }
209   }
210 
211   if (needOpenCl)
212     needOcml = needOckl = true;
213 
214   // No libraries needed (the typical case)
215   if (!(needOpenCl || needOcml || needOckl))
216     return ret;
217 
218   // Define one of the control constants the ROCm device libraries expect to be
219   // present These constants can either be defined in the module or can be
220   // imported by linking in bitcode that defines the constant. To simplify our
221   // logic, we define the constants into the module we are compiling
222   auto addControlConstant = [&module = *ret](StringRef name, uint32_t value,
223                                              uint32_t bitwidth) {
224     using llvm::GlobalVariable;
225     if (module.getNamedGlobal(name)) {
226       return;
227     }
228     llvm::IntegerType *type =
229         llvm::IntegerType::getIntNTy(module.getContext(), bitwidth);
230     auto *initializer = llvm::ConstantInt::get(type, value, /*isSigned=*/false);
231     auto *constant = new GlobalVariable(
232         module, type,
233         /*isConstant=*/true, GlobalVariable::LinkageTypes::LinkOnceODRLinkage,
234         initializer, name,
235         /*before=*/nullptr,
236         /*threadLocalMode=*/GlobalVariable::ThreadLocalMode::NotThreadLocal,
237         /*addressSpace=*/4);
238     constant->setUnnamedAddr(GlobalVariable::UnnamedAddr::Local);
239     constant->setVisibility(
240         GlobalVariable::VisibilityTypes::ProtectedVisibility);
241     constant->setAlignment(llvm::MaybeAlign(bitwidth / 8));
242   };
243 
244   // Set up control variables in the module instead of linking in tiny bitcode
245   if (needOcml) {
246     // TODO(kdrewnia): Enable math optimizations once we have support for
247     // `-ffast-math`-like options
248     addControlConstant("__oclc_finite_only_opt", 0, 8);
249     addControlConstant("__oclc_daz_opt", 0, 8);
250     addControlConstant("__oclc_correctly_rounded_sqrt32", 1, 8);
251     addControlConstant("__oclc_unsafe_math_opt", 0, 8);
252   }
253   if (needOcml || needOckl) {
254     addControlConstant("__oclc_wavefrontsize64", 1, 8);
255     StringRef chipSet = this->chip.getValue();
256     if (chipSet.startswith("gfx"))
257       chipSet = chipSet.substr(3);
258     uint32_t minor =
259         llvm::APInt(32, chipSet.substr(chipSet.size() - 2), 16).getZExtValue();
260     uint32_t major = llvm::APInt(32, chipSet.substr(0, chipSet.size() - 2), 10)
261                          .getZExtValue();
262     uint32_t isaNumber = minor + 1000 * major;
263     addControlConstant("__oclc_ISA_version", isaNumber, 32);
264   }
265 
266   // Determine libraries we need to link - order matters due to dependencies
267   llvm::SmallVector<StringRef, 4> libraries;
268   if (needOpenCl)
269     libraries.push_back("opencl.bc");
270   if (needOcml)
271     libraries.push_back("ocml.bc");
272   if (needOckl)
273     libraries.push_back("ockl.bc");
274 
275   Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>> mbModules;
276   std::string theRocmPath = getRocmPath();
277   llvm::SmallString<32> bitcodePath(std::move(theRocmPath));
278   llvm::sys::path::append(bitcodePath, "amdgcn", "bitcode");
279   mbModules = loadLibraries(bitcodePath, libraries, llvmContext);
280 
281   if (!mbModules) {
282     getOperation()
283             .emitWarning("Could not load required device libraries")
284             .attachNote()
285         << "This will probably cause link-time or run-time failures";
286     return ret; // We can still abort here
287   }
288 
289   llvm::Linker linker(*ret);
290   for (std::unique_ptr<llvm::Module> &libModule : mbModules.getValue()) {
291     // This bitcode linking code is substantially similar to what is used in
292     // hip-clang It imports the library functions into the module, allowing LLVM
293     // optimization passes (which must run after linking) to optimize across the
294     // libraries and the module's code. We also only import symbols if they are
295     // referenced by the module or a previous library since there will be no
296     // other source of references to those symbols in this compilation and since
297     // we don't want to bloat the resulting code object.
298     bool err = linker.linkInModule(
299         std::move(libModule), llvm::Linker::Flags::LinkOnlyNeeded,
300         [](llvm::Module &m, const StringSet<> &gvs) {
301           llvm::internalizeModule(m, [&gvs](const llvm::GlobalValue &gv) {
302             return !gv.hasName() || (gvs.count(gv.getName()) == 0);
303           });
304         });
305     // True is linker failure
306     if (err) {
307       getOperation().emitError(
308           "Unrecoverable failure during device library linking.");
309       // We have no guaranties about the state of `ret`, so bail
310       return nullptr;
311     }
312   }
313 
314   return ret;
315 }
316 
317 LogicalResult
318 SerializeToHsacoPass::optimizeLlvm(llvm::Module &llvmModule,
319                                    llvm::TargetMachine &targetMachine) {
320   int optLevel = this->optLevel.getValue();
321   if (optLevel < 0 || optLevel > 3)
322     return getOperation().emitError()
323            << "Invalid HSA optimization level" << optLevel << "\n";
324 
325   targetMachine.setOptLevel(static_cast<llvm::CodeGenOpt::Level>(optLevel));
326 
327   auto transformer =
328       makeOptimizingTransformer(optLevel, /*sizeLevel=*/0, &targetMachine);
329   auto error = transformer(&llvmModule);
330   if (error) {
331     InFlightDiagnostic mlirError = getOperation()->emitError();
332     llvm::handleAllErrors(
333         std::move(error), [&mlirError](const llvm::ErrorInfoBase &ei) {
334           mlirError << "Could not optimize LLVM IR: " << ei.message() << "\n";
335         });
336     return mlirError;
337   }
338   return success();
339 }
340 
341 std::unique_ptr<SmallVectorImpl<char>>
342 SerializeToHsacoPass::assembleIsa(const std::string &isa) {
343   auto loc = getOperation().getLoc();
344 
345   SmallVector<char, 0> result;
346   llvm::raw_svector_ostream os(result);
347 
348   llvm::Triple triple(llvm::Triple::normalize(this->triple));
349   std::string error;
350   const llvm::Target *target =
351       llvm::TargetRegistry::lookupTarget(triple.normalize(), error);
352   if (!target) {
353     emitError(loc, Twine("failed to lookup target: ") + error);
354     return {};
355   }
356 
357   llvm::SourceMgr srcMgr;
358   srcMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(isa),
359                             SMLoc());
360 
361   const llvm::MCTargetOptions mcOptions;
362   std::unique_ptr<llvm::MCRegisterInfo> mri(
363       target->createMCRegInfo(this->triple));
364   std::unique_ptr<llvm::MCAsmInfo> mai(
365       target->createMCAsmInfo(*mri, this->triple, mcOptions));
366   mai->setRelaxELFRelocations(true);
367   std::unique_ptr<llvm::MCSubtargetInfo> sti(
368       target->createMCSubtargetInfo(this->triple, this->chip, this->features));
369 
370   llvm::MCContext ctx(triple, mai.get(), mri.get(), sti.get(), &srcMgr,
371                       &mcOptions);
372   std::unique_ptr<llvm::MCObjectFileInfo> mofi(target->createMCObjectFileInfo(
373       ctx, /*PIC=*/false, /*LargeCodeModel=*/false));
374   ctx.setObjectFileInfo(mofi.get());
375 
376   SmallString<128> cwd;
377   if (!llvm::sys::fs::current_path(cwd))
378     ctx.setCompilationDir(cwd);
379 
380   std::unique_ptr<llvm::MCStreamer> mcStreamer;
381   std::unique_ptr<llvm::MCInstrInfo> mcii(target->createMCInstrInfo());
382 
383   llvm::MCCodeEmitter *ce = target->createMCCodeEmitter(*mcii, ctx);
384   llvm::MCAsmBackend *mab = target->createMCAsmBackend(*sti, *mri, mcOptions);
385   mcStreamer.reset(target->createMCObjectStreamer(
386       triple, ctx, std::unique_ptr<llvm::MCAsmBackend>(mab),
387       mab->createObjectWriter(os), std::unique_ptr<llvm::MCCodeEmitter>(ce),
388       *sti, mcOptions.MCRelaxAll, mcOptions.MCIncrementalLinkerCompatible,
389       /*DWARFMustBeAtTheEnd*/ false));
390   mcStreamer->setUseAssemblerInfoForParsing(true);
391 
392   std::unique_ptr<llvm::MCAsmParser> parser(
393       createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai));
394   std::unique_ptr<llvm::MCTargetAsmParser> tap(
395       target->createMCAsmParser(*sti, *parser, *mcii, mcOptions));
396 
397   if (!tap) {
398     emitError(loc, "assembler initialization error");
399     return {};
400   }
401 
402   parser->setTargetParser(*tap);
403   parser->Run(false);
404 
405   return std::make_unique<SmallVector<char, 0>>(std::move(result));
406 }
407 
408 std::unique_ptr<std::vector<char>>
409 SerializeToHsacoPass::createHsaco(const SmallVectorImpl<char> &isaBinary) {
410   auto loc = getOperation().getLoc();
411 
412   // Save the ISA binary to a temp file.
413   int tempIsaBinaryFd = -1;
414   SmallString<128> tempIsaBinaryFilename;
415   if (llvm::sys::fs::createTemporaryFile("kernel", "o", tempIsaBinaryFd,
416                                          tempIsaBinaryFilename)) {
417     emitError(loc, "temporary file for ISA binary creation error");
418     return {};
419   }
420   llvm::FileRemover cleanupIsaBinary(tempIsaBinaryFilename);
421   llvm::raw_fd_ostream tempIsaBinaryOs(tempIsaBinaryFd, true);
422   tempIsaBinaryOs << StringRef(isaBinary.data(), isaBinary.size());
423   tempIsaBinaryOs.close();
424 
425   // Create a temp file for HSA code object.
426   int tempHsacoFD = -1;
427   SmallString<128> tempHsacoFilename;
428   if (llvm::sys::fs::createTemporaryFile("kernel", "hsaco", tempHsacoFD,
429                                          tempHsacoFilename)) {
430     emitError(loc, "temporary file for HSA code object creation error");
431     return {};
432   }
433   llvm::FileRemover cleanupHsaco(tempHsacoFilename);
434 
435   std::string theRocmPath = getRocmPath();
436   llvm::SmallString<32> lldPath(std::move(theRocmPath));
437   llvm::sys::path::append(lldPath, "llvm", "bin", "ld.lld");
438   int lldResult = llvm::sys::ExecuteAndWait(
439       lldPath,
440       {"ld.lld", "-shared", tempIsaBinaryFilename, "-o", tempHsacoFilename});
441   if (lldResult != 0) {
442     emitError(loc, "lld invocation error");
443     return {};
444   }
445 
446   // Load the HSA code object.
447   auto hsacoFile = openInputFile(tempHsacoFilename);
448   if (!hsacoFile) {
449     emitError(loc, "read HSA code object from temp file error");
450     return {};
451   }
452 
453   StringRef buffer = hsacoFile->getBuffer();
454   return std::make_unique<std::vector<char>>(buffer.begin(), buffer.end());
455 }
456 
457 std::unique_ptr<std::vector<char>>
458 SerializeToHsacoPass::serializeISA(const std::string &isa) {
459   auto isaBinary = assembleIsa(isa);
460   if (!isaBinary)
461     return {};
462   return createHsaco(*isaBinary);
463 }
464 
465 // Register pass to serialize GPU kernel functions to a HSACO binary annotation.
466 void mlir::registerGpuSerializeToHsacoPass() {
467   PassRegistration<SerializeToHsacoPass> registerSerializeToHSACO(
468       [] {
469         // Initialize LLVM AMDGPU backend.
470         LLVMInitializeAMDGPUAsmParser();
471         LLVMInitializeAMDGPUAsmPrinter();
472         LLVMInitializeAMDGPUTarget();
473         LLVMInitializeAMDGPUTargetInfo();
474         LLVMInitializeAMDGPUTargetMC();
475 
476         return std::make_unique<SerializeToHsacoPass>("amdgcn-amd-amdhsa", "",
477                                                       "", 2);
478       });
479 }
480 
481 /// Create an instance of the GPU kernel function to HSAco binary serialization
482 /// pass.
483 std::unique_ptr<Pass> mlir::createGpuSerializeToHsacoPass(StringRef triple,
484                                                           StringRef arch,
485                                                           StringRef features,
486                                                           int optLevel) {
487   return std::make_unique<SerializeToHsacoPass>(triple, arch, features,
488                                                 optLevel);
489 }
490 
491 #else  // MLIR_GPU_TO_HSACO_PASS_ENABLE
// Stub used when MLIR_GPU_TO_HSACO_PASS_ENABLE is off: the pass is compiled
// out, so there is nothing to register.
void mlir::registerGpuSerializeToHsacoPass() {}
493 #endif // MLIR_GPU_TO_HSACO_PASS_ENABLE
494