//===- DXILOpLower.cpp - Lowering LLVM intrinsic to DXILOp function -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains passes and utilities to lower LLVM intrinsic calls
/// to DXILOp function calls.
11 //===----------------------------------------------------------------------===//
12 
13 #include "DXILConstants.h"
14 #include "DirectX.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/CodeGen/Passes.h"
17 #include "llvm/IR/IRBuilder.h"
18 #include "llvm/IR/Instruction.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/IR/PassManager.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Support/ErrorHandling.h"
24 
25 #define DEBUG_TYPE "dxil-op-lower"
26 
27 using namespace llvm;
28 using namespace llvm::DXIL;
29 
30 constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
31 
/// Overload (result) type categories of a DXIL operation.
///
/// Every kind owns a distinct bit, so a set of permitted overloads can be held
/// in a single uint16_t mask and tested with a bitwise AND. The scalar kinds
/// (VOID .. I64) are numerically smaller than UserDefineType and ObjectType.
enum OverloadKind : uint16_t {
  VOID = 0x0001,
  HALF = 0x0002,
  FLOAT = 0x0004,
  DOUBLE = 0x0008,
  I1 = 0x0010,
  I8 = 0x0020,
  I16 = 0x0040,
  I32 = 0x0080,
  I64 = 0x0100,
  UserDefineType = 0x0200,
  ObjectType = 0x0400,
};
45 
46 static const char *getOverloadTypeName(OverloadKind Kind) {
47   switch (Kind) {
48   case OverloadKind::HALF:
49     return "f16";
50   case OverloadKind::FLOAT:
51     return "f32";
52   case OverloadKind::DOUBLE:
53     return "f64";
54   case OverloadKind::I1:
55     return "i1";
56   case OverloadKind::I8:
57     return "i8";
58   case OverloadKind::I16:
59     return "i16";
60   case OverloadKind::I32:
61     return "i32";
62   case OverloadKind::I64:
63     return "i64";
64   case OverloadKind::VOID:
65   case OverloadKind::ObjectType:
66   case OverloadKind::UserDefineType:
67     break;
68   }
69   llvm_unreachable("invalid overload type for name");
70   return "void";
71 }
72 
73 static OverloadKind getOverloadKind(Type *Ty) {
74   Type::TypeID T = Ty->getTypeID();
75   switch (T) {
76   case Type::VoidTyID:
77     return OverloadKind::VOID;
78   case Type::HalfTyID:
79     return OverloadKind::HALF;
80   case Type::FloatTyID:
81     return OverloadKind::FLOAT;
82   case Type::DoubleTyID:
83     return OverloadKind::DOUBLE;
84   case Type::IntegerTyID: {
85     IntegerType *ITy = cast<IntegerType>(Ty);
86     unsigned Bits = ITy->getBitWidth();
87     switch (Bits) {
88     case 1:
89       return OverloadKind::I1;
90     case 8:
91       return OverloadKind::I8;
92     case 16:
93       return OverloadKind::I16;
94     case 32:
95       return OverloadKind::I32;
96     case 64:
97       return OverloadKind::I64;
98     default:
99       llvm_unreachable("invalid overload type");
100       return OverloadKind::VOID;
101     }
102   }
103   case Type::PointerTyID:
104     return OverloadKind::UserDefineType;
105   case Type::StructTyID:
106     return OverloadKind::ObjectType;
107   default:
108     llvm_unreachable("invalid overload type");
109     return OverloadKind::VOID;
110   }
111 }
112 
113 static std::string getTypeName(OverloadKind Kind, Type *Ty) {
114   if (Kind < OverloadKind::UserDefineType) {
115     return getOverloadTypeName(Kind);
116   } else if (Kind == OverloadKind::UserDefineType) {
117     StructType *ST = cast<StructType>(Ty);
118     return ST->getStructName().str();
119   } else if (Kind == OverloadKind::ObjectType) {
120     StructType *ST = cast<StructType>(Ty);
121     return ST->getStructName().str();
122   } else {
123     std::string Str;
124     raw_string_ostream OS(Str);
125     Ty->print(OS);
126     return OS.str();
127   }
128 }
129 
// Static properties of a single DXIL operation.
// Entries are looked up through getOpCodeProperty, which is generated by
// TableGen into DXILOperation.inc.
// NOTE(review): the generated table presumably initializes these fields in
// declaration order, so keep the order and names in sync with the generator.
// TODO confirm.
struct OpCodeProperty {
  DXIL::OpCode OpCode;
  // Offset in DXILOpCodeNameTable.
  unsigned OpCodeNameOffset;
  // Operation class of this opcode. DXIL declaration names are built from the
  // class name (getOpCodeClassName), not from the individual opcode.
  DXIL::OpCodeClass OpCodeClass;
  // Offset in DXILOpCodeClassNameTable.
  unsigned OpCodeClassNameOffset;
  // Bit mask of OverloadKind values this operation accepts as its overload
  // (return) type; tested with a bitwise AND against a single kind.
  uint16_t OverloadTys;
  // NOTE(review): presumably the function attribute intended for the op's
  // declaration. It is not read by the visible code in this file.
  llvm::Attribute::AttrKind FuncAttr;
};
141 
// Include getOpCodeClassName, getOpCodeProperty and getOpCodeName, which are
// generated by TableGen.
144 #define DXIL_OP_OPERATION_TABLE
145 #include "DXILOperation.inc"
146 #undef DXIL_OP_OPERATION_TABLE
147 
148 static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
149                                          const OpCodeProperty &Prop) {
150   if (Kind == OverloadKind::VOID) {
151     return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
152   }
153   return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
154           getTypeName(Kind, Ty))
155       .str();
156 }
157 
158 static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F,
159                                            Module &M) {
160   const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
161 
162   // Get return type as overload type for DXILOp.
163   // Only simple mapping case here, so return type is good enough.
164   Type *OverloadTy = F.getReturnType();
165 
166   OverloadKind Kind = getOverloadKind(OverloadTy);
167   // FIXME: find the issue and report error in clang instead of check it in
168   // backend.
169   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
170     llvm_unreachable("invalid overload");
171   }
172 
173   std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
174   assert(!M.getFunction(FnName) && "Function already exists");
175 
176   auto &Ctx = M.getContext();
177   Type *OpCodeTy = Type::getInt32Ty(Ctx);
178 
179   SmallVector<Type *> ArgTypes;
180   // DXIL has i32 opcode as first arg.
181   ArgTypes.emplace_back(OpCodeTy);
182   FunctionType *FT = F.getFunctionType();
183   ArgTypes.append(FT->param_begin(), FT->param_end());
184   FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false);
185   return M.getOrInsertFunction(FnName, DXILOpFT);
186 }
187 
188 static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) {
189   auto DXILOpFn = createDXILOpFunction(DXILOp, F, M);
190   IRBuilder<> B(M.getContext());
191   Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
192   for (User *U : make_early_inc_range(F.users())) {
193     CallInst *CI = dyn_cast<CallInst>(U);
194     if (!CI)
195       continue;
196 
197     SmallVector<Value *> Args;
198     Args.emplace_back(DXILOpArg);
199     Args.append(CI->arg_begin(), CI->arg_end());
200     B.SetInsertPoint(CI);
201     CallInst *DXILCI = B.CreateCall(DXILOpFn, Args);
202     LLVM_DEBUG(DXILCI->setName(getOpCodeName(DXILOp)));
203     CI->replaceAllUsesWith(DXILCI);
204     CI->eraseFromParent();
205   }
206   if (F.user_empty())
207     F.eraseFromParent();
208 }
209 
210 static bool lowerIntrinsics(Module &M) {
211   bool Updated = false;
212 
213 #define DXIL_OP_INTRINSIC_MAP
214 #include "DXILOperation.inc"
215 #undef DXIL_OP_INTRINSIC_MAP
216 
217   for (Function &F : make_early_inc_range(M.functions())) {
218     if (!F.isDeclaration())
219       continue;
220     Intrinsic::ID ID = F.getIntrinsicID();
221     if (ID == Intrinsic::not_intrinsic)
222       continue;
223     auto LowerIt = LowerMap.find(ID);
224     if (LowerIt == LowerMap.end())
225       continue;
226     lowerIntrinsic(LowerIt->second, F, M);
227     Updated = true;
228   }
229   return Updated;
230 }
231 
232 namespace {
233 /// A pass that transforms external global definitions into declarations.
234 class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
235 public:
236   PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
237     if (lowerIntrinsics(M))
238       return PreservedAnalyses::none();
239     return PreservedAnalyses::all();
240   }
241 };
242 } // namespace
243 
244 namespace {
245 class DXILOpLoweringLegacy : public ModulePass {
246 public:
247   bool runOnModule(Module &M) override { return lowerIntrinsics(M); }
248   StringRef getPassName() const override { return "DXIL Op Lowering"; }
249   DXILOpLoweringLegacy() : ModulePass(ID) {}
250 
251   static char ID; // Pass identification.
252 };
253 char DXILOpLoweringLegacy::ID = 0;
254 
255 } // end anonymous namespace
256 
// Register the legacy pass under the command-line name DEBUG_TYPE
// ("dxil-op-lower") with the description "DXIL Op Lowering". The two trailing
// `false` arguments mark it as neither CFG-only nor an analysis. No pass
// dependencies are declared between BEGIN and END.
INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
                      false, false)
INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
                    false)
261 
262 ModulePass *llvm::createDXILOpLoweringLegacyPass() {
263   return new DXILOpLoweringLegacy();
264 }
265