//===- DXILOpBuilder.cpp - Helper class for building DXILOp functions -----===//
2*57006b14SXiang Li //
3*57006b14SXiang Li // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*57006b14SXiang Li // See https://llvm.org/LICENSE.txt for license information.
5*57006b14SXiang Li // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*57006b14SXiang Li //
7*57006b14SXiang Li //===----------------------------------------------------------------------===//
8*57006b14SXiang Li ///
/// \file This file contains a class that helps build DXIL op functions.
10*57006b14SXiang Li //===----------------------------------------------------------------------===//
11*57006b14SXiang Li 
12*57006b14SXiang Li #include "DXILOpBuilder.h"
13*57006b14SXiang Li #include "DXILConstants.h"
14*57006b14SXiang Li #include "llvm/IR/IRBuilder.h"
15*57006b14SXiang Li #include "llvm/IR/Module.h"
16*57006b14SXiang Li #include "llvm/Support/DXILOperationCommon.h"
17*57006b14SXiang Li #include "llvm/Support/ErrorHandling.h"
18*57006b14SXiang Li 
19*57006b14SXiang Li using namespace llvm;
20*57006b14SXiang Li using namespace llvm::DXIL;
21*57006b14SXiang Li 
22*57006b14SXiang Li constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
23*57006b14SXiang Li 
namespace {

// One bit per overload type a DXIL operation may be specialized on. Because
// every enumerator owns a distinct bit, several kinds can be combined into a
// mask (see OpCodeProperty::OverloadTys) and tested with a bitwise AND.
enum OverloadKind : uint16_t {
  VOID = 0x0001,
  HALF = 0x0002,
  FLOAT = 0x0004,
  DOUBLE = 0x0008,
  I1 = 0x0010,
  I8 = 0x0020,
  I16 = 0x0040,
  I32 = 0x0080,
  I64 = 0x0100,
  UserDefineType = 0x0200,
  ObjectType = 0x0400,
};

} // namespace
41*57006b14SXiang Li 
getOverloadTypeName(OverloadKind Kind)42*57006b14SXiang Li static const char *getOverloadTypeName(OverloadKind Kind) {
43*57006b14SXiang Li   switch (Kind) {
44*57006b14SXiang Li   case OverloadKind::HALF:
45*57006b14SXiang Li     return "f16";
46*57006b14SXiang Li   case OverloadKind::FLOAT:
47*57006b14SXiang Li     return "f32";
48*57006b14SXiang Li   case OverloadKind::DOUBLE:
49*57006b14SXiang Li     return "f64";
50*57006b14SXiang Li   case OverloadKind::I1:
51*57006b14SXiang Li     return "i1";
52*57006b14SXiang Li   case OverloadKind::I8:
53*57006b14SXiang Li     return "i8";
54*57006b14SXiang Li   case OverloadKind::I16:
55*57006b14SXiang Li     return "i16";
56*57006b14SXiang Li   case OverloadKind::I32:
57*57006b14SXiang Li     return "i32";
58*57006b14SXiang Li   case OverloadKind::I64:
59*57006b14SXiang Li     return "i64";
60*57006b14SXiang Li   case OverloadKind::VOID:
61*57006b14SXiang Li   case OverloadKind::ObjectType:
62*57006b14SXiang Li   case OverloadKind::UserDefineType:
63*57006b14SXiang Li     break;
64*57006b14SXiang Li   }
65*57006b14SXiang Li   llvm_unreachable("invalid overload type for name");
66*57006b14SXiang Li   return "void";
67*57006b14SXiang Li }
68*57006b14SXiang Li 
getOverloadKind(Type * Ty)69*57006b14SXiang Li static OverloadKind getOverloadKind(Type *Ty) {
70*57006b14SXiang Li   Type::TypeID T = Ty->getTypeID();
71*57006b14SXiang Li   switch (T) {
72*57006b14SXiang Li   case Type::VoidTyID:
73*57006b14SXiang Li     return OverloadKind::VOID;
74*57006b14SXiang Li   case Type::HalfTyID:
75*57006b14SXiang Li     return OverloadKind::HALF;
76*57006b14SXiang Li   case Type::FloatTyID:
77*57006b14SXiang Li     return OverloadKind::FLOAT;
78*57006b14SXiang Li   case Type::DoubleTyID:
79*57006b14SXiang Li     return OverloadKind::DOUBLE;
80*57006b14SXiang Li   case Type::IntegerTyID: {
81*57006b14SXiang Li     IntegerType *ITy = cast<IntegerType>(Ty);
82*57006b14SXiang Li     unsigned Bits = ITy->getBitWidth();
83*57006b14SXiang Li     switch (Bits) {
84*57006b14SXiang Li     case 1:
85*57006b14SXiang Li       return OverloadKind::I1;
86*57006b14SXiang Li     case 8:
87*57006b14SXiang Li       return OverloadKind::I8;
88*57006b14SXiang Li     case 16:
89*57006b14SXiang Li       return OverloadKind::I16;
90*57006b14SXiang Li     case 32:
91*57006b14SXiang Li       return OverloadKind::I32;
92*57006b14SXiang Li     case 64:
93*57006b14SXiang Li       return OverloadKind::I64;
94*57006b14SXiang Li     default:
95*57006b14SXiang Li       llvm_unreachable("invalid overload type");
96*57006b14SXiang Li       return OverloadKind::VOID;
97*57006b14SXiang Li     }
98*57006b14SXiang Li   }
99*57006b14SXiang Li   case Type::PointerTyID:
100*57006b14SXiang Li     return OverloadKind::UserDefineType;
101*57006b14SXiang Li   case Type::StructTyID:
102*57006b14SXiang Li     return OverloadKind::ObjectType;
103*57006b14SXiang Li   default:
104*57006b14SXiang Li     llvm_unreachable("invalid overload type");
105*57006b14SXiang Li     return OverloadKind::VOID;
106*57006b14SXiang Li   }
107*57006b14SXiang Li }
108*57006b14SXiang Li 
getTypeName(OverloadKind Kind,Type * Ty)109*57006b14SXiang Li static std::string getTypeName(OverloadKind Kind, Type *Ty) {
110*57006b14SXiang Li   if (Kind < OverloadKind::UserDefineType) {
111*57006b14SXiang Li     return getOverloadTypeName(Kind);
112*57006b14SXiang Li   } else if (Kind == OverloadKind::UserDefineType) {
113*57006b14SXiang Li     StructType *ST = cast<StructType>(Ty);
114*57006b14SXiang Li     return ST->getStructName().str();
115*57006b14SXiang Li   } else if (Kind == OverloadKind::ObjectType) {
116*57006b14SXiang Li     StructType *ST = cast<StructType>(Ty);
117*57006b14SXiang Li     return ST->getStructName().str();
118*57006b14SXiang Li   } else {
119*57006b14SXiang Li     std::string Str;
120*57006b14SXiang Li     raw_string_ostream OS(Str);
121*57006b14SXiang Li     Ty->print(OS);
122*57006b14SXiang Li     return OS.str();
123*57006b14SXiang Li   }
124*57006b14SXiang Li }
125*57006b14SXiang Li 
126*57006b14SXiang Li // Static properties.
127*57006b14SXiang Li struct OpCodeProperty {
128*57006b14SXiang Li   DXIL::OpCode OpCode;
129*57006b14SXiang Li   // Offset in DXILOpCodeNameTable.
130*57006b14SXiang Li   unsigned OpCodeNameOffset;
131*57006b14SXiang Li   DXIL::OpCodeClass OpCodeClass;
132*57006b14SXiang Li   // Offset in DXILOpCodeClassNameTable.
133*57006b14SXiang Li   unsigned OpCodeClassNameOffset;
134*57006b14SXiang Li   uint16_t OverloadTys;
135*57006b14SXiang Li   llvm::Attribute::AttrKind FuncAttr;
136*57006b14SXiang Li   int OverloadParamIndex;        // parameter index which control the overload.
137*57006b14SXiang Li                                  // When < 0, should be only 1 overload type.
138*57006b14SXiang Li   unsigned NumOfParameters;      // Number of parameters include return value.
139*57006b14SXiang Li   unsigned ParameterTableOffset; // Offset in ParameterTable.
140*57006b14SXiang Li };
141*57006b14SXiang Li 
// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
144*57006b14SXiang Li #define DXIL_OP_OPERATION_TABLE
145*57006b14SXiang Li #include "DXILOperation.inc"
146*57006b14SXiang Li #undef DXIL_OP_OPERATION_TABLE
147*57006b14SXiang Li 
constructOverloadName(OverloadKind Kind,Type * Ty,const OpCodeProperty & Prop)148*57006b14SXiang Li static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
149*57006b14SXiang Li                                          const OpCodeProperty &Prop) {
150*57006b14SXiang Li   if (Kind == OverloadKind::VOID) {
151*57006b14SXiang Li     return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
152*57006b14SXiang Li   }
153*57006b14SXiang Li   return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
154*57006b14SXiang Li           getTypeName(Kind, Ty))
155*57006b14SXiang Li       .str();
156*57006b14SXiang Li }
157*57006b14SXiang Li 
constructOverloadTypeName(OverloadKind Kind,StringRef TypeName)158*57006b14SXiang Li static std::string constructOverloadTypeName(OverloadKind Kind,
159*57006b14SXiang Li                                              StringRef TypeName) {
160*57006b14SXiang Li   if (Kind == OverloadKind::VOID)
161*57006b14SXiang Li     return TypeName.str();
162*57006b14SXiang Li 
163*57006b14SXiang Li   assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
164*57006b14SXiang Li   return (Twine(TypeName) + getOverloadTypeName(Kind)).str();
165*57006b14SXiang Li }
166*57006b14SXiang Li 
getOrCreateStructType(StringRef Name,ArrayRef<Type * > EltTys,LLVMContext & Ctx)167*57006b14SXiang Li static StructType *getOrCreateStructType(StringRef Name,
168*57006b14SXiang Li                                          ArrayRef<Type *> EltTys,
169*57006b14SXiang Li                                          LLVMContext &Ctx) {
170*57006b14SXiang Li   StructType *ST = StructType::getTypeByName(Ctx, Name);
171*57006b14SXiang Li   if (ST)
172*57006b14SXiang Li     return ST;
173*57006b14SXiang Li 
174*57006b14SXiang Li   return StructType::create(Ctx, EltTys, Name);
175*57006b14SXiang Li }
176*57006b14SXiang Li 
getResRetType(Type * OverloadTy,LLVMContext & Ctx)177*57006b14SXiang Li static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
178*57006b14SXiang Li   OverloadKind Kind = getOverloadKind(OverloadTy);
179*57006b14SXiang Li   std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
180*57006b14SXiang Li   Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
181*57006b14SXiang Li                          Type::getInt32Ty(Ctx)};
182*57006b14SXiang Li   return getOrCreateStructType(TypeName, FieldTypes, Ctx);
183*57006b14SXiang Li }
184*57006b14SXiang Li 
getHandleType(LLVMContext & Ctx)185*57006b14SXiang Li static StructType *getHandleType(LLVMContext &Ctx) {
186*57006b14SXiang Li   return getOrCreateStructType("dx.types.Handle", Type::getInt8PtrTy(Ctx), Ctx);
187*57006b14SXiang Li }
188*57006b14SXiang Li 
getTypeFromParameterKind(ParameterKind Kind,Type * OverloadTy)189*57006b14SXiang Li static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
190*57006b14SXiang Li   auto &Ctx = OverloadTy->getContext();
191*57006b14SXiang Li   switch (Kind) {
192*57006b14SXiang Li   case ParameterKind::VOID:
193*57006b14SXiang Li     return Type::getVoidTy(Ctx);
194*57006b14SXiang Li   case ParameterKind::HALF:
195*57006b14SXiang Li     return Type::getHalfTy(Ctx);
196*57006b14SXiang Li   case ParameterKind::FLOAT:
197*57006b14SXiang Li     return Type::getFloatTy(Ctx);
198*57006b14SXiang Li   case ParameterKind::DOUBLE:
199*57006b14SXiang Li     return Type::getDoubleTy(Ctx);
200*57006b14SXiang Li   case ParameterKind::I1:
201*57006b14SXiang Li     return Type::getInt1Ty(Ctx);
202*57006b14SXiang Li   case ParameterKind::I8:
203*57006b14SXiang Li     return Type::getInt8Ty(Ctx);
204*57006b14SXiang Li   case ParameterKind::I16:
205*57006b14SXiang Li     return Type::getInt16Ty(Ctx);
206*57006b14SXiang Li   case ParameterKind::I32:
207*57006b14SXiang Li     return Type::getInt32Ty(Ctx);
208*57006b14SXiang Li   case ParameterKind::I64:
209*57006b14SXiang Li     return Type::getInt64Ty(Ctx);
210*57006b14SXiang Li   case ParameterKind::OVERLOAD:
211*57006b14SXiang Li     return OverloadTy;
212*57006b14SXiang Li   case ParameterKind::RESOURCE_RET:
213*57006b14SXiang Li     return getResRetType(OverloadTy, Ctx);
214*57006b14SXiang Li   case ParameterKind::DXIL_HANDLE:
215*57006b14SXiang Li     return getHandleType(Ctx);
216*57006b14SXiang Li   default:
217*57006b14SXiang Li     break;
218*57006b14SXiang Li   }
219*57006b14SXiang Li   llvm_unreachable("Invalid parameter kind");
220*57006b14SXiang Li   return nullptr;
221*57006b14SXiang Li }
222*57006b14SXiang Li 
getDXILOpFunctionType(const OpCodeProperty * Prop,Type * OverloadTy)223*57006b14SXiang Li static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
224*57006b14SXiang Li                                            Type *OverloadTy) {
225*57006b14SXiang Li   SmallVector<Type *> ArgTys;
226*57006b14SXiang Li 
227*57006b14SXiang Li   auto ParamKinds = getOpCodeParameterKind(*Prop);
228*57006b14SXiang Li 
229*57006b14SXiang Li   for (unsigned I = 0; I < Prop->NumOfParameters; ++I) {
230*57006b14SXiang Li     ParameterKind Kind = ParamKinds[I];
231*57006b14SXiang Li     ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy));
232*57006b14SXiang Li   }
233*57006b14SXiang Li   return FunctionType::get(
234*57006b14SXiang Li       ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
235*57006b14SXiang Li }
236*57006b14SXiang Li 
getOrCreateDXILOpFunction(DXIL::OpCode DXILOp,Type * OverloadTy,Module & M)237*57006b14SXiang Li static FunctionCallee getOrCreateDXILOpFunction(DXIL::OpCode DXILOp,
238*57006b14SXiang Li                                                 Type *OverloadTy, Module &M) {
239*57006b14SXiang Li   const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
240*57006b14SXiang Li 
241*57006b14SXiang Li   OverloadKind Kind = getOverloadKind(OverloadTy);
242*57006b14SXiang Li   // FIXME: find the issue and report error in clang instead of check it in
243*57006b14SXiang Li   // backend.
244*57006b14SXiang Li   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
245*57006b14SXiang Li     llvm_unreachable("invalid overload");
246*57006b14SXiang Li   }
247*57006b14SXiang Li 
248*57006b14SXiang Li   std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
249*57006b14SXiang Li   // Dependent on name to dedup.
250*57006b14SXiang Li   if (auto *Fn = M.getFunction(FnName))
251*57006b14SXiang Li     return FunctionCallee(Fn);
252*57006b14SXiang Li 
253*57006b14SXiang Li   FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
254*57006b14SXiang Li   return M.getOrInsertFunction(FnName, DXILOpFT);
255*57006b14SXiang Li }
256*57006b14SXiang Li 
257*57006b14SXiang Li namespace llvm {
258*57006b14SXiang Li namespace DXIL {
259*57006b14SXiang Li 
createDXILOpCall(DXIL::OpCode OpCode,Type * OverloadTy,llvm::iterator_range<Use * > Args)260*57006b14SXiang Li CallInst *DXILOpBuilder::createDXILOpCall(DXIL::OpCode OpCode, Type *OverloadTy,
261*57006b14SXiang Li                                           llvm::iterator_range<Use *> Args) {
262*57006b14SXiang Li   auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
263*57006b14SXiang Li   SmallVector<Value *> FullArgs;
264*57006b14SXiang Li   FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
265*57006b14SXiang Li   FullArgs.append(Args.begin(), Args.end());
266*57006b14SXiang Li   return B.CreateCall(Fn, FullArgs);
267*57006b14SXiang Li }
268*57006b14SXiang Li 
getOverloadTy(DXIL::OpCode OpCode,FunctionType * FT,bool NoOpCodeParam)269*57006b14SXiang Li Type *DXILOpBuilder::getOverloadTy(DXIL::OpCode OpCode, FunctionType *FT,
270*57006b14SXiang Li                                    bool NoOpCodeParam) {
271*57006b14SXiang Li 
272*57006b14SXiang Li   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
273*57006b14SXiang Li   if (Prop->OverloadParamIndex < 0) {
274*57006b14SXiang Li     auto &Ctx = FT->getContext();
275*57006b14SXiang Li     // When only has 1 overload type, just return it.
276*57006b14SXiang Li     switch (Prop->OverloadTys) {
277*57006b14SXiang Li     case OverloadKind::VOID:
278*57006b14SXiang Li       return Type::getVoidTy(Ctx);
279*57006b14SXiang Li     case OverloadKind::HALF:
280*57006b14SXiang Li       return Type::getHalfTy(Ctx);
281*57006b14SXiang Li     case OverloadKind::FLOAT:
282*57006b14SXiang Li       return Type::getFloatTy(Ctx);
283*57006b14SXiang Li     case OverloadKind::DOUBLE:
284*57006b14SXiang Li       return Type::getDoubleTy(Ctx);
285*57006b14SXiang Li     case OverloadKind::I1:
286*57006b14SXiang Li       return Type::getInt1Ty(Ctx);
287*57006b14SXiang Li     case OverloadKind::I8:
288*57006b14SXiang Li       return Type::getInt8Ty(Ctx);
289*57006b14SXiang Li     case OverloadKind::I16:
290*57006b14SXiang Li       return Type::getInt16Ty(Ctx);
291*57006b14SXiang Li     case OverloadKind::I32:
292*57006b14SXiang Li       return Type::getInt32Ty(Ctx);
293*57006b14SXiang Li     case OverloadKind::I64:
294*57006b14SXiang Li       return Type::getInt64Ty(Ctx);
295*57006b14SXiang Li     default:
296*57006b14SXiang Li       llvm_unreachable("invalid overload type");
297*57006b14SXiang Li       return nullptr;
298*57006b14SXiang Li     }
299*57006b14SXiang Li   }
300*57006b14SXiang Li 
301*57006b14SXiang Li   // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
302*57006b14SXiang Li   Type *OverloadType = FT->getReturnType();
303*57006b14SXiang Li   if (Prop->OverloadParamIndex != 0) {
304*57006b14SXiang Li     // Skip Return Type and Type for DXIL opcode.
305*57006b14SXiang Li     const unsigned SkipedParam = NoOpCodeParam ? 2 : 1;
306*57006b14SXiang Li     OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkipedParam);
307*57006b14SXiang Li   }
308*57006b14SXiang Li 
309*57006b14SXiang Li   auto ParamKinds = getOpCodeParameterKind(*Prop);
310*57006b14SXiang Li   auto Kind = ParamKinds[Prop->OverloadParamIndex];
311*57006b14SXiang Li   // For ResRet and CBufferRet, OverloadTy is in field of StructType.
312*57006b14SXiang Li   if (Kind == ParameterKind::CBUFFER_RET ||
313*57006b14SXiang Li       Kind == ParameterKind::RESOURCE_RET) {
314*57006b14SXiang Li     auto *ST = cast<StructType>(OverloadType);
315*57006b14SXiang Li     OverloadType = ST->getElementType(0);
316*57006b14SXiang Li   }
317*57006b14SXiang Li   return OverloadType;
318*57006b14SXiang Li }
319*57006b14SXiang Li 
getOpCodeName(DXIL::OpCode DXILOp)320*57006b14SXiang Li const char *DXILOpBuilder::getOpCodeName(DXIL::OpCode DXILOp) {
321*57006b14SXiang Li   return ::getOpCodeName(DXILOp);
322*57006b14SXiang Li }
323*57006b14SXiang Li } // namespace DXIL
324*57006b14SXiang Li } // namespace llvm
325