10c7f7f1bSpython3kgae //===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata ---*- C++ -*-===//
20c7f7f1bSpython3kgae //
30c7f7f1bSpython3kgae // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40c7f7f1bSpython3kgae // See https://llvm.org/LICENSE.txt for license information.
50c7f7f1bSpython3kgae // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60c7f7f1bSpython3kgae //
70c7f7f1bSpython3kgae //===----------------------------------------------------------------------===//
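///
/// \file
/// This file contains a module pass that emits DXIL metadata: it writes the
/// DXIL validator version as the `dx.valver` named metadata and strips module
/// flags whose information has been translated into that metadata.
///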
90c7f7f1bSpython3kgae //===----------------------------------------------------------------------===//
100c7f7f1bSpython3kgae 
110c7f7f1bSpython3kgae #include "DirectX.h"
12c6c13d4eSpython3kgae #include "llvm/ADT/StringSet.h"
130c7f7f1bSpython3kgae #include "llvm/ADT/Triple.h"
140c7f7f1bSpython3kgae #include "llvm/IR/Constants.h"
15c6c13d4eSpython3kgae #include "llvm/IR/Metadata.h"
160c7f7f1bSpython3kgae #include "llvm/IR/Module.h"
170c7f7f1bSpython3kgae #include "llvm/Pass.h"
180c7f7f1bSpython3kgae 
190c7f7f1bSpython3kgae using namespace llvm;
200c7f7f1bSpython3kgae 
ConstMDToUint32(const MDOperand & MDO)210c7f7f1bSpython3kgae static uint32_t ConstMDToUint32(const MDOperand &MDO) {
220c7f7f1bSpython3kgae   ConstantInt *pConst = mdconst::extract<ConstantInt>(MDO);
230c7f7f1bSpython3kgae   return (uint32_t)pConst->getZExtValue();
240c7f7f1bSpython3kgae }
250c7f7f1bSpython3kgae 
Uint32ToConstMD(unsigned v,LLVMContext & Ctx)260c7f7f1bSpython3kgae static ConstantAsMetadata *Uint32ToConstMD(unsigned v, LLVMContext &Ctx) {
270c7f7f1bSpython3kgae   return ConstantAsMetadata::get(
280c7f7f1bSpython3kgae       Constant::getIntegerValue(IntegerType::get(Ctx, 32), APInt(32, v)));
290c7f7f1bSpython3kgae }
300c7f7f1bSpython3kgae 
310c7f7f1bSpython3kgae constexpr StringLiteral ValVerKey = "dx.valver";
320c7f7f1bSpython3kgae constexpr unsigned DXILVersionNumFields = 2;
330c7f7f1bSpython3kgae 
emitDXILValidatorVersion(Module & M,VersionTuple & ValidatorVer)340c7f7f1bSpython3kgae static void emitDXILValidatorVersion(Module &M, VersionTuple &ValidatorVer) {
350c7f7f1bSpython3kgae   NamedMDNode *DXILValidatorVersionMD = M.getNamedMetadata(ValVerKey);
360c7f7f1bSpython3kgae 
370c7f7f1bSpython3kgae   // Allow re-writing the validator version, since this can be changed at
380c7f7f1bSpython3kgae   // later points.
390c7f7f1bSpython3kgae   if (DXILValidatorVersionMD)
400c7f7f1bSpython3kgae     M.eraseNamedMetadata(DXILValidatorVersionMD);
410c7f7f1bSpython3kgae 
420c7f7f1bSpython3kgae   DXILValidatorVersionMD = M.getOrInsertNamedMetadata(ValVerKey);
430c7f7f1bSpython3kgae 
440c7f7f1bSpython3kgae   auto &Ctx = M.getContext();
450c7f7f1bSpython3kgae   Metadata *MDVals[DXILVersionNumFields];
460c7f7f1bSpython3kgae   MDVals[0] = Uint32ToConstMD(ValidatorVer.getMajor(), Ctx);
47*129b531cSKazu Hirata   MDVals[1] = Uint32ToConstMD(ValidatorVer.getMinor().value_or(0), Ctx);
480c7f7f1bSpython3kgae 
490c7f7f1bSpython3kgae   DXILValidatorVersionMD->addOperand(MDNode::get(Ctx, MDVals));
500c7f7f1bSpython3kgae }
510c7f7f1bSpython3kgae 
loadDXILValidatorVersion(MDNode * ValVerMD)520c7f7f1bSpython3kgae static VersionTuple loadDXILValidatorVersion(MDNode *ValVerMD) {
530c7f7f1bSpython3kgae   if (ValVerMD->getNumOperands() != DXILVersionNumFields)
540c7f7f1bSpython3kgae     return VersionTuple();
550c7f7f1bSpython3kgae 
560c7f7f1bSpython3kgae   unsigned Major = ConstMDToUint32(ValVerMD->getOperand(0));
570c7f7f1bSpython3kgae   unsigned Minor = ConstMDToUint32(ValVerMD->getOperand(1));
580c7f7f1bSpython3kgae   return VersionTuple(Major, Minor);
590c7f7f1bSpython3kgae }
600c7f7f1bSpython3kgae 
cleanModuleFlags(Module & M)61c6c13d4eSpython3kgae static void cleanModuleFlags(Module &M) {
62c6c13d4eSpython3kgae   constexpr StringLiteral DeadKeys[] = {ValVerKey};
63c6c13d4eSpython3kgae   // Collect DeadKeys in ModuleFlags.
64c6c13d4eSpython3kgae   StringSet<> DeadKeySet;
65c6c13d4eSpython3kgae   for (auto &Key : DeadKeys) {
66c6c13d4eSpython3kgae     if (M.getModuleFlag(Key))
67c6c13d4eSpython3kgae       DeadKeySet.insert(Key);
680c7f7f1bSpython3kgae   }
69c6c13d4eSpython3kgae   if (DeadKeySet.empty())
70c6c13d4eSpython3kgae     return;
71c6c13d4eSpython3kgae 
72c6c13d4eSpython3kgae   SmallVector<Module::ModuleFlagEntry, 8> ModuleFlags;
73c6c13d4eSpython3kgae   M.getModuleFlagsMetadata(ModuleFlags);
74c6c13d4eSpython3kgae   NamedMDNode *MDFlags = M.getModuleFlagsMetadata();
75c6c13d4eSpython3kgae   MDFlags->eraseFromParent();
76c6c13d4eSpython3kgae   // Add ModuleFlag which not dead.
77c6c13d4eSpython3kgae   for (auto &Flag : ModuleFlags) {
78c6c13d4eSpython3kgae     StringRef Key = Flag.Key->getString();
79c6c13d4eSpython3kgae     if (DeadKeySet.contains(Key))
80c6c13d4eSpython3kgae       continue;
81c6c13d4eSpython3kgae     M.addModuleFlag(Flag.Behavior, Key, Flag.Val);
82c6c13d4eSpython3kgae   }
83c6c13d4eSpython3kgae }
84c6c13d4eSpython3kgae 
cleanModule(Module & M)85c6c13d4eSpython3kgae static void cleanModule(Module &M) { cleanModuleFlags(M); }
860c7f7f1bSpython3kgae 
870c7f7f1bSpython3kgae namespace {
880c7f7f1bSpython3kgae class DXILTranslateMetadata : public ModulePass {
890c7f7f1bSpython3kgae public:
900c7f7f1bSpython3kgae   static char ID; // Pass identification, replacement for typeid
DXILTranslateMetadata()910c7f7f1bSpython3kgae   explicit DXILTranslateMetadata() : ModulePass(ID), ValidatorVer(1, 0) {}
920c7f7f1bSpython3kgae 
getPassName() const930c7f7f1bSpython3kgae   StringRef getPassName() const override { return "DXIL Metadata Emit"; }
940c7f7f1bSpython3kgae 
950c7f7f1bSpython3kgae   bool runOnModule(Module &M) override;
960c7f7f1bSpython3kgae 
970c7f7f1bSpython3kgae private:
980c7f7f1bSpython3kgae   VersionTuple ValidatorVer;
990c7f7f1bSpython3kgae };
1000c7f7f1bSpython3kgae 
1010c7f7f1bSpython3kgae } // namespace
1020c7f7f1bSpython3kgae 
runOnModule(Module & M)1030c7f7f1bSpython3kgae bool DXILTranslateMetadata::runOnModule(Module &M) {
1040c7f7f1bSpython3kgae   if (MDNode *ValVerMD = cast_or_null<MDNode>(M.getModuleFlag(ValVerKey))) {
1050c7f7f1bSpython3kgae     auto ValVer = loadDXILValidatorVersion(ValVerMD);
1060c7f7f1bSpython3kgae     if (!ValVer.empty())
1070c7f7f1bSpython3kgae       ValidatorVer = ValVer;
1080c7f7f1bSpython3kgae   }
1090c7f7f1bSpython3kgae   emitDXILValidatorVersion(M, ValidatorVer);
1100c7f7f1bSpython3kgae   cleanModule(M);
1110c7f7f1bSpython3kgae   return false;
1120c7f7f1bSpython3kgae }
1130c7f7f1bSpython3kgae 
1140c7f7f1bSpython3kgae char DXILTranslateMetadata::ID = 0;
1150c7f7f1bSpython3kgae 
createDXILTranslateMetadataPass()1160c7f7f1bSpython3kgae ModulePass *llvm::createDXILTranslateMetadataPass() {
1170c7f7f1bSpython3kgae   return new DXILTranslateMetadata();
1180c7f7f1bSpython3kgae }
1190c7f7f1bSpython3kgae 
1200c7f7f1bSpython3kgae INITIALIZE_PASS(DXILTranslateMetadata, "dxil-metadata-emit",
1210c7f7f1bSpython3kgae                 "DXIL Metadata Emit", false, false)
122