1 //===- LowerGPUToHSACO.cpp - Convert GPU kernel to HSACO blob -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass that serializes a gpu module into HSAco blob and
10 // adds that blob as a string attribute of the module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/GPU/Transforms/Passes.h"
15 #include "mlir/IR/Location.h"
16 #include "mlir/IR/MLIRContext.h"
17 
18 #if MLIR_GPU_TO_HSACO_PASS_ENABLE
19 #include "mlir/ExecutionEngine/OptUtils.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Support/FileUtilities.h"
22 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
23 #include "mlir/Target/LLVMIR/Export.h"
24 
25 #include "llvm/IR/Constants.h"
26 #include "llvm/IR/GlobalVariable.h"
27 #include "llvm/IR/Module.h"
28 #include "llvm/IRReader/IRReader.h"
29 #include "llvm/Linker/Linker.h"
30 
31 #include "llvm/MC/MCAsmBackend.h"
32 #include "llvm/MC/MCAsmInfo.h"
33 #include "llvm/MC/MCCodeEmitter.h"
34 #include "llvm/MC/MCContext.h"
35 #include "llvm/MC/MCInstrInfo.h"
36 #include "llvm/MC/MCObjectFileInfo.h"
37 #include "llvm/MC/MCObjectWriter.h"
38 #include "llvm/MC/MCParser/MCTargetAsmParser.h"
39 #include "llvm/MC/MCRegisterInfo.h"
40 #include "llvm/MC/MCStreamer.h"
41 #include "llvm/MC/MCSubtargetInfo.h"
42 #include "llvm/MC/TargetRegistry.h"
43 
44 #include "llvm/Support/CommandLine.h"
45 #include "llvm/Support/FileUtilities.h"
46 #include "llvm/Support/Path.h"
47 #include "llvm/Support/Program.h"
48 #include "llvm/Support/SourceMgr.h"
49 #include "llvm/Support/TargetSelect.h"
50 #include "llvm/Support/WithColor.h"
51 
52 #include "llvm/Target/TargetMachine.h"
53 #include "llvm/Target/TargetOptions.h"
54 
55 #include "llvm/Transforms/IPO/Internalize.h"
56 
#include <cstdlib>
#include <mutex>
58 
59 using namespace mlir;
60 
61 namespace {
62 class SerializeToHsacoPass
63     : public PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass> {
64 public:
65   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SerializeToHsacoPass)
66 
67   SerializeToHsacoPass(StringRef triple, StringRef arch, StringRef features,
68                        int optLevel);
69   SerializeToHsacoPass(const SerializeToHsacoPass &other);
70   StringRef getArgument() const override { return "gpu-to-hsaco"; }
71   StringRef getDescription() const override {
72     return "Lower GPU kernel function to HSACO binary annotations";
73   }
74 
75 protected:
76   Option<int> optLevel{
77       *this, "opt-level",
78       llvm::cl::desc("Optimization level for HSACO compilation"),
79       llvm::cl::init(2)};
80 
81   Option<std::string> rocmPath{*this, "rocm-path",
82                                llvm::cl::desc("Path to ROCm install")};
83 
84   // Overload to allow linking in device libs
85   std::unique_ptr<llvm::Module>
86   translateToLLVMIR(llvm::LLVMContext &llvmContext) override;
87 
88   /// Adds LLVM optimization passes
89   LogicalResult optimizeLlvm(llvm::Module &llvmModule,
90                              llvm::TargetMachine &targetMachine) override;
91 
92 private:
93   void getDependentDialects(DialectRegistry &registry) const override;
94 
95   // Loads LLVM bitcode libraries
96   Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
97   loadLibraries(SmallVectorImpl<char> &path,
98                 SmallVectorImpl<StringRef> &libraries,
99                 llvm::LLVMContext &context);
100 
101   // Serializes ROCDL to HSACO.
102   std::unique_ptr<std::vector<char>>
103   serializeISA(const std::string &isa) override;
104 
105   std::unique_ptr<SmallVectorImpl<char>> assembleIsa(const std::string &isa);
106   std::unique_ptr<std::vector<char>>
107   createHsaco(const SmallVectorImpl<char> &isaBinary);
108 
109   std::string getRocmPath();
110 };
111 } // namespace
112 
113 SerializeToHsacoPass::SerializeToHsacoPass(const SerializeToHsacoPass &other)
114     : PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass>(other) {}
115 
116 /// Get a user-specified path to ROCm
117 // Tries, in order, the --rocm-path option, the ROCM_PATH environment variable
118 // and a compile-time default
119 std::string SerializeToHsacoPass::getRocmPath() {
120   if (rocmPath.getNumOccurrences() > 0)
121     return rocmPath.getValue();
122 
123   return __DEFAULT_ROCM_PATH__;
124 }
125 
126 // Sets the 'option' to 'value' unless it already has a value.
127 static void maybeSetOption(Pass::Option<std::string> &option,
128                            function_ref<std::string()> getValue) {
129   if (!option.hasValue())
130     option = getValue();
131 }
132 
133 SerializeToHsacoPass::SerializeToHsacoPass(StringRef triple, StringRef arch,
134                                            StringRef features, int optLevel) {
135   maybeSetOption(this->triple, [&triple] { return triple.str(); });
136   maybeSetOption(this->chip, [&arch] { return arch.str(); });
137   maybeSetOption(this->features, [&features] { return features.str(); });
138   if (this->optLevel.getNumOccurrences() == 0)
139     this->optLevel.setValue(optLevel);
140 }
141 
142 void SerializeToHsacoPass::getDependentDialects(
143     DialectRegistry &registry) const {
144   registerROCDLDialectTranslation(registry);
145   gpu::SerializeToBlobPass::getDependentDialects(registry);
146 }
147 
148 Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
149 SerializeToHsacoPass::loadLibraries(SmallVectorImpl<char> &path,
150                                     SmallVectorImpl<StringRef> &libraries,
151                                     llvm::LLVMContext &context) {
152   SmallVector<std::unique_ptr<llvm::Module>, 3> ret;
153   size_t dirLength = path.size();
154 
155   if (!llvm::sys::fs::is_directory(path)) {
156     getOperation().emitRemark() << "Bitcode path: " << path
157                                 << " does not exist or is not a directory\n";
158     return llvm::None;
159   }
160 
161   for (const StringRef file : libraries) {
162     llvm::SMDiagnostic error;
163     llvm::sys::path::append(path, file);
164     llvm::StringRef pathRef(path.data(), path.size());
165     std::unique_ptr<llvm::Module> library =
166         llvm::getLazyIRFileModule(pathRef, error, context);
167     path.truncate(dirLength);
168     if (!library) {
169       getOperation().emitError() << "Failed to load library " << file
170                                  << " from " << path << error.getMessage();
171       return llvm::None;
172     }
173     // Some ROCM builds don't strip this like they should
174     if (auto *openclVersion = library->getNamedMetadata("opencl.ocl.version"))
175       library->eraseNamedMetadata(openclVersion);
176     // Stop spamming us with clang version numbers
177     if (auto *ident = library->getNamedMetadata("llvm.ident"))
178       library->eraseNamedMetadata(ident);
179     ret.push_back(std::move(library));
180   }
181 
182   return ret;
183 }
184 
185 std::unique_ptr<llvm::Module>
186 SerializeToHsacoPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) {
187   // MLIR -> LLVM translation
188   std::unique_ptr<llvm::Module> ret =
189       gpu::SerializeToBlobPass::translateToLLVMIR(llvmContext);
190 
191   if (!ret) {
192     getOperation().emitOpError("Module lowering failed");
193     return ret;
194   }
195   // Walk the LLVM module in order to determine if we need to link in device
196   // libs
197   bool needOpenCl = false;
198   bool needOckl = false;
199   bool needOcml = false;
200   for (llvm::Function &f : ret->functions()) {
201     if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) {
202       StringRef funcName = f.getName();
203       if ("printf" == funcName)
204         needOpenCl = true;
205       if (funcName.startswith("__ockl_"))
206         needOckl = true;
207       if (funcName.startswith("__ocml_"))
208         needOcml = true;
209     }
210   }
211 
212   if (needOpenCl)
213     needOcml = needOckl = true;
214 
215   // No libraries needed (the typical case)
216   if (!(needOpenCl || needOcml || needOckl))
217     return ret;
218 
219   // Define one of the control constants the ROCm device libraries expect to be
220   // present These constants can either be defined in the module or can be
221   // imported by linking in bitcode that defines the constant. To simplify our
222   // logic, we define the constants into the module we are compiling
223   auto addControlConstant = [&module = *ret](StringRef name, uint32_t value,
224                                              uint32_t bitwidth) {
225     using llvm::GlobalVariable;
226     if (module.getNamedGlobal(name)) {
227       return;
228     }
229     llvm::IntegerType *type =
230         llvm::IntegerType::getIntNTy(module.getContext(), bitwidth);
231     auto *initializer = llvm::ConstantInt::get(type, value, /*isSigned=*/false);
232     auto *constant = new GlobalVariable(
233         module, type,
234         /*isConstant=*/true, GlobalVariable::LinkageTypes::LinkOnceODRLinkage,
235         initializer, name,
236         /*before=*/nullptr,
237         /*threadLocalMode=*/GlobalVariable::ThreadLocalMode::NotThreadLocal,
238         /*addressSpace=*/4);
239     constant->setUnnamedAddr(GlobalVariable::UnnamedAddr::Local);
240     constant->setVisibility(
241         GlobalVariable::VisibilityTypes::ProtectedVisibility);
242     constant->setAlignment(llvm::MaybeAlign(bitwidth / 8));
243   };
244 
245   // Set up control variables in the module instead of linking in tiny bitcode
246   if (needOcml) {
247     // TODO(kdrewnia): Enable math optimizations once we have support for
248     // `-ffast-math`-like options
249     addControlConstant("__oclc_finite_only_opt", 0, 8);
250     addControlConstant("__oclc_daz_opt", 0, 8);
251     addControlConstant("__oclc_correctly_rounded_sqrt32", 1, 8);
252     addControlConstant("__oclc_unsafe_math_opt", 0, 8);
253   }
254   if (needOcml || needOckl) {
255     addControlConstant("__oclc_wavefrontsize64", 1, 8);
256     StringRef chipSet = this->chip.getValue();
257     if (chipSet.startswith("gfx"))
258       chipSet = chipSet.substr(3);
259     uint32_t minor =
260         llvm::APInt(32, chipSet.substr(chipSet.size() - 2), 16).getZExtValue();
261     uint32_t major = llvm::APInt(32, chipSet.substr(0, chipSet.size() - 2), 10)
262                          .getZExtValue();
263     uint32_t isaNumber = minor + 1000 * major;
264     addControlConstant("__oclc_ISA_version", isaNumber, 32);
265   }
266 
267   // Determine libraries we need to link - order matters due to dependencies
268   llvm::SmallVector<StringRef, 4> libraries;
269   if (needOpenCl)
270     libraries.push_back("opencl.bc");
271   if (needOcml)
272     libraries.push_back("ocml.bc");
273   if (needOckl)
274     libraries.push_back("ockl.bc");
275 
276   Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>> mbModules;
277   std::string theRocmPath = getRocmPath();
278   llvm::SmallString<32> bitcodePath(theRocmPath);
279   llvm::sys::path::append(bitcodePath, "amdgcn", "bitcode");
280   mbModules = loadLibraries(bitcodePath, libraries, llvmContext);
281 
282   if (!mbModules) {
283     getOperation()
284             .emitWarning("Could not load required device libraries")
285             .attachNote()
286         << "This will probably cause link-time or run-time failures";
287     return ret; // We can still abort here
288   }
289 
290   llvm::Linker linker(*ret);
291   for (std::unique_ptr<llvm::Module> &libModule : mbModules.getValue()) {
292     // This bitcode linking code is substantially similar to what is used in
293     // hip-clang It imports the library functions into the module, allowing LLVM
294     // optimization passes (which must run after linking) to optimize across the
295     // libraries and the module's code. We also only import symbols if they are
296     // referenced by the module or a previous library since there will be no
297     // other source of references to those symbols in this compilation and since
298     // we don't want to bloat the resulting code object.
299     bool err = linker.linkInModule(
300         std::move(libModule), llvm::Linker::Flags::LinkOnlyNeeded,
301         [](llvm::Module &m, const StringSet<> &gvs) {
302           llvm::internalizeModule(m, [&gvs](const llvm::GlobalValue &gv) {
303             return !gv.hasName() || (gvs.count(gv.getName()) == 0);
304           });
305         });
306     // True is linker failure
307     if (err) {
308       getOperation().emitError(
309           "Unrecoverable failure during device library linking.");
310       // We have no guaranties about the state of `ret`, so bail
311       return nullptr;
312     }
313   }
314 
315   return ret;
316 }
317 
318 LogicalResult
319 SerializeToHsacoPass::optimizeLlvm(llvm::Module &llvmModule,
320                                    llvm::TargetMachine &targetMachine) {
321   int optLevel = this->optLevel.getValue();
322   if (optLevel < 0 || optLevel > 3)
323     return getOperation().emitError()
324            << "Invalid HSA optimization level" << optLevel << "\n";
325 
326   targetMachine.setOptLevel(static_cast<llvm::CodeGenOpt::Level>(optLevel));
327 
328   auto transformer =
329       makeOptimizingTransformer(optLevel, /*sizeLevel=*/0, &targetMachine);
330   auto error = transformer(&llvmModule);
331   if (error) {
332     InFlightDiagnostic mlirError = getOperation()->emitError();
333     llvm::handleAllErrors(
334         std::move(error), [&mlirError](const llvm::ErrorInfoBase &ei) {
335           mlirError << "Could not optimize LLVM IR: " << ei.message() << "\n";
336         });
337     return mlirError;
338   }
339   return success();
340 }
341 
342 std::unique_ptr<SmallVectorImpl<char>>
343 SerializeToHsacoPass::assembleIsa(const std::string &isa) {
344   auto loc = getOperation().getLoc();
345 
346   SmallVector<char, 0> result;
347   llvm::raw_svector_ostream os(result);
348 
349   llvm::Triple triple(llvm::Triple::normalize(this->triple));
350   std::string error;
351   const llvm::Target *target =
352       llvm::TargetRegistry::lookupTarget(triple.normalize(), error);
353   if (!target) {
354     emitError(loc, Twine("failed to lookup target: ") + error);
355     return {};
356   }
357 
358   llvm::SourceMgr srcMgr;
359   srcMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(isa),
360                             SMLoc());
361 
362   const llvm::MCTargetOptions mcOptions;
363   std::unique_ptr<llvm::MCRegisterInfo> mri(
364       target->createMCRegInfo(this->triple));
365   std::unique_ptr<llvm::MCAsmInfo> mai(
366       target->createMCAsmInfo(*mri, this->triple, mcOptions));
367   mai->setRelaxELFRelocations(true);
368   std::unique_ptr<llvm::MCSubtargetInfo> sti(
369       target->createMCSubtargetInfo(this->triple, this->chip, this->features));
370 
371   llvm::MCContext ctx(triple, mai.get(), mri.get(), sti.get(), &srcMgr,
372                       &mcOptions);
373   std::unique_ptr<llvm::MCObjectFileInfo> mofi(target->createMCObjectFileInfo(
374       ctx, /*PIC=*/false, /*LargeCodeModel=*/false));
375   ctx.setObjectFileInfo(mofi.get());
376 
377   SmallString<128> cwd;
378   if (!llvm::sys::fs::current_path(cwd))
379     ctx.setCompilationDir(cwd);
380 
381   std::unique_ptr<llvm::MCStreamer> mcStreamer;
382   std::unique_ptr<llvm::MCInstrInfo> mcii(target->createMCInstrInfo());
383 
384   llvm::MCCodeEmitter *ce = target->createMCCodeEmitter(*mcii, ctx);
385   llvm::MCAsmBackend *mab = target->createMCAsmBackend(*sti, *mri, mcOptions);
386   mcStreamer.reset(target->createMCObjectStreamer(
387       triple, ctx, std::unique_ptr<llvm::MCAsmBackend>(mab),
388       mab->createObjectWriter(os), std::unique_ptr<llvm::MCCodeEmitter>(ce),
389       *sti, mcOptions.MCRelaxAll, mcOptions.MCIncrementalLinkerCompatible,
390       /*DWARFMustBeAtTheEnd*/ false));
391   mcStreamer->setUseAssemblerInfoForParsing(true);
392 
393   std::unique_ptr<llvm::MCAsmParser> parser(
394       createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai));
395   std::unique_ptr<llvm::MCTargetAsmParser> tap(
396       target->createMCAsmParser(*sti, *parser, *mcii, mcOptions));
397 
398   if (!tap) {
399     emitError(loc, "assembler initialization error");
400     return {};
401   }
402 
403   parser->setTargetParser(*tap);
404   parser->Run(false);
405 
406   return std::make_unique<SmallVector<char, 0>>(std::move(result));
407 }
408 
409 std::unique_ptr<std::vector<char>>
410 SerializeToHsacoPass::createHsaco(const SmallVectorImpl<char> &isaBinary) {
411   auto loc = getOperation().getLoc();
412 
413   // Save the ISA binary to a temp file.
414   int tempIsaBinaryFd = -1;
415   SmallString<128> tempIsaBinaryFilename;
416   if (llvm::sys::fs::createTemporaryFile("kernel", "o", tempIsaBinaryFd,
417                                          tempIsaBinaryFilename)) {
418     emitError(loc, "temporary file for ISA binary creation error");
419     return {};
420   }
421   llvm::FileRemover cleanupIsaBinary(tempIsaBinaryFilename);
422   llvm::raw_fd_ostream tempIsaBinaryOs(tempIsaBinaryFd, true);
423   tempIsaBinaryOs << StringRef(isaBinary.data(), isaBinary.size());
424   tempIsaBinaryOs.close();
425 
426   // Create a temp file for HSA code object.
427   int tempHsacoFD = -1;
428   SmallString<128> tempHsacoFilename;
429   if (llvm::sys::fs::createTemporaryFile("kernel", "hsaco", tempHsacoFD,
430                                          tempHsacoFilename)) {
431     emitError(loc, "temporary file for HSA code object creation error");
432     return {};
433   }
434   llvm::FileRemover cleanupHsaco(tempHsacoFilename);
435 
436   std::string theRocmPath = getRocmPath();
437   llvm::SmallString<32> lldPath(theRocmPath);
438   llvm::sys::path::append(lldPath, "llvm", "bin", "ld.lld");
439   int lldResult = llvm::sys::ExecuteAndWait(
440       lldPath,
441       {"ld.lld", "-shared", tempIsaBinaryFilename, "-o", tempHsacoFilename});
442   if (lldResult != 0) {
443     emitError(loc, "lld invocation error");
444     return {};
445   }
446 
447   // Load the HSA code object.
448   auto hsacoFile = openInputFile(tempHsacoFilename);
449   if (!hsacoFile) {
450     emitError(loc, "read HSA code object from temp file error");
451     return {};
452   }
453 
454   StringRef buffer = hsacoFile->getBuffer();
455   return std::make_unique<std::vector<char>>(buffer.begin(), buffer.end());
456 }
457 
458 std::unique_ptr<std::vector<char>>
459 SerializeToHsacoPass::serializeISA(const std::string &isa) {
460   auto isaBinary = assembleIsa(isa);
461   if (!isaBinary)
462     return {};
463   return createHsaco(*isaBinary);
464 }
465 
466 // Register pass to serialize GPU kernel functions to a HSACO binary annotation.
467 void mlir::registerGpuSerializeToHsacoPass() {
468   PassRegistration<SerializeToHsacoPass> registerSerializeToHSACO(
469       [] {
470         // Initialize LLVM AMDGPU backend.
471         LLVMInitializeAMDGPUAsmParser();
472         LLVMInitializeAMDGPUAsmPrinter();
473         LLVMInitializeAMDGPUTarget();
474         LLVMInitializeAMDGPUTargetInfo();
475         LLVMInitializeAMDGPUTargetMC();
476 
477         return std::make_unique<SerializeToHsacoPass>("amdgcn-amd-amdhsa", "",
478                                                       "", 2);
479       });
480 }
481 
482 /// Create an instance of the GPU kernel function to HSAco binary serialization
483 /// pass.
484 std::unique_ptr<Pass> mlir::createGpuSerializeToHsacoPass(StringRef triple,
485                                                           StringRef arch,
486                                                           StringRef features,
487                                                           int optLevel) {
488   return std::make_unique<SerializeToHsacoPass>(triple, arch, features,
489                                                 optLevel);
490 }
491 
492 #else  // MLIR_GPU_TO_HSACO_PASS_ENABLE
493 void mlir::registerGpuSerializeToHsacoPass() {}
494 #endif // MLIR_GPU_TO_HSACO_PASS_ENABLE
495