//===- DXILOpLower.cpp - Lowering LLVM intrinsic to DXILOp function -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains passes and utilities to lower LLVM intrinsic calls
/// to DXILOp function calls.
11 //===----------------------------------------------------------------------===//
12 
13 #include "DXILConstants.h"
14 #include "DirectX.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/CodeGen/Passes.h"
17 #include "llvm/IR/IRBuilder.h"
18 #include "llvm/IR/Instruction.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/IR/PassManager.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Support/ErrorHandling.h"
24 
25 #define DEBUG_TYPE "dxil-op-lower"
26 
27 using namespace llvm;
28 using namespace llvm::DXIL;
29 
30 constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
31 
namespace {
/// Bit flags classifying the overload (result) type of a DXIL operation.
///
/// Every enumerator is a distinct bit so that several of them can be OR'd
/// together into a mask of legal overloads for an opcode. The enum is kept in
/// an anonymous namespace because it is private to this file and its
/// enumerator names (VOID, FLOAT, I1, ...) are very generic.
enum OverloadKind : uint16_t {
  VOID = 1,
  HALF = 1 << 1,
  FLOAT = 1 << 2,
  DOUBLE = 1 << 3,
  I1 = 1 << 4,
  I8 = 1 << 5,
  I16 = 1 << 6,
  I32 = 1 << 7,
  I64 = 1 << 8,
  UserDefineType = 1 << 9,
  ObjectType = 1 << 10,
};
} // namespace
45 
46 static const char *getOverloadTypeName(OverloadKind Kind) {
47   switch (Kind) {
48   case OverloadKind::HALF:
49     return "f16";
50   case OverloadKind::FLOAT:
51     return "f32";
52   case OverloadKind::DOUBLE:
53     return "f64";
54   case OverloadKind::I1:
55     return "i1";
56   case OverloadKind::I8:
57     return "i8";
58   case OverloadKind::I16:
59     return "i16";
60   case OverloadKind::I32:
61     return "i32";
62   case OverloadKind::I64:
63     return "i64";
64   case OverloadKind::VOID:
65   case OverloadKind::ObjectType:
66   case OverloadKind::UserDefineType:
67     llvm_unreachable("invalid overload type for name");
68     return "void";
69   }
70 }
71 
72 static OverloadKind getOverloadKind(Type *Ty) {
73   Type::TypeID T = Ty->getTypeID();
74   switch (T) {
75   case Type::VoidTyID:
76     return OverloadKind::VOID;
77   case Type::HalfTyID:
78     return OverloadKind::HALF;
79   case Type::FloatTyID:
80     return OverloadKind::FLOAT;
81   case Type::DoubleTyID:
82     return OverloadKind::DOUBLE;
83   case Type::IntegerTyID: {
84     IntegerType *ITy = cast<IntegerType>(Ty);
85     unsigned Bits = ITy->getBitWidth();
86     switch (Bits) {
87     case 1:
88       return OverloadKind::I1;
89     case 8:
90       return OverloadKind::I8;
91     case 16:
92       return OverloadKind::I16;
93     case 32:
94       return OverloadKind::I32;
95     case 64:
96       return OverloadKind::I64;
97     default:
98       llvm_unreachable("invalid overload type");
99       return OverloadKind::VOID;
100     }
101   }
102   case Type::PointerTyID:
103     return OverloadKind::UserDefineType;
104   case Type::StructTyID:
105     return OverloadKind::ObjectType;
106   default:
107     llvm_unreachable("invalid overload type");
108     return OverloadKind::VOID;
109   }
110 }
111 
112 static std::string getTypeName(OverloadKind Kind, Type *Ty) {
113   if (Kind < OverloadKind::UserDefineType) {
114     return getOverloadTypeName(Kind);
115   } else if (Kind == OverloadKind::UserDefineType) {
116     StructType *ST = cast<StructType>(Ty);
117     return ST->getStructName().str();
118   } else if (Kind == OverloadKind::ObjectType) {
119     StructType *ST = cast<StructType>(Ty);
120     return ST->getStructName().str();
121   } else {
122     std::string Str;
123     raw_string_ostream OS(Str);
124     Ty->print(OS);
125     return OS.str();
126   }
127 }
128 
129 // Static properties.
130 struct OpCodeProperty {
131   DXIL::OpCode OpCode;
132   // FIXME: change OpCodeName into index to a large string constant when move to
133   // tableGen.
134   const char *OpCodeName;
135   DXIL::OpCodeClass OpCodeClass;
136   uint16_t OverloadTys;
137   llvm::Attribute::AttrKind FuncAttr;
138 };
139 
140 static const char *getOpCodeClassName(const OpCodeProperty &Prop) {
141   // FIXME: generate this table with tableGen.
142   static const char *OpCodeClassNames[] = {
143       "binary",
144       "unary",
145   };
146   unsigned Index = static_cast<unsigned>(Prop.OpCodeClass);
147   assert(Index < (sizeof(OpCodeClassNames) / sizeof(OpCodeClassNames[0])) &&
148          "Out of bound OpCodeClass");
149   return OpCodeClassNames[Index];
150 }
151 
152 static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
153                                          const OpCodeProperty &Prop) {
154   if (Kind == OverloadKind::VOID) {
155     return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
156   }
157   return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
158           getTypeName(Kind, Ty))
159       .str();
160 }
161 
162 static const OpCodeProperty *getOpCodeProperty(DXIL::OpCode DXILOp) {
163   // FIXME: generate this table with tableGen.
164   static const OpCodeProperty OpCodeProps[] = {
165       {DXIL::OpCode::Sin, "Sin", OpCodeClass::Unary,
166        OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone},
167       {DXIL::OpCode::UMax, "UMax", OpCodeClass::Binary,
168        OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64,
169        Attribute::AttrKind::ReadNone},
170   };
171   // FIXME: change search to indexing with
172   // DXILOp once all DXIL op is added.
173   OpCodeProperty TmpProp;
174   TmpProp.OpCode = DXILOp;
175   const OpCodeProperty *Prop =
176       llvm::lower_bound(OpCodeProps, TmpProp,
177                         [](const OpCodeProperty &A, const OpCodeProperty &B) {
178                           return A.OpCode < B.OpCode;
179                         });
180   return Prop;
181 }
182 
183 static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F,
184                                            Module &M) {
185   const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
186 
187   // Get return type as overload type for DXILOp.
188   // Only simple mapping case here, so return type is good enough.
189   Type *OverloadTy = F.getReturnType();
190 
191   OverloadKind Kind = getOverloadKind(OverloadTy);
192   // FIXME: find the issue and report error in clang instead of check it in
193   // backend.
194   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
195     llvm_unreachable("invalid overload");
196   }
197 
198   std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
199   assert(!M.getFunction(FnName) && "Function already exists");
200 
201   auto &Ctx = M.getContext();
202   Type *OpCodeTy = Type::getInt32Ty(Ctx);
203 
204   SmallVector<Type *> ArgTypes;
205   // DXIL has i32 opcode as first arg.
206   ArgTypes.emplace_back(OpCodeTy);
207   FunctionType *FT = F.getFunctionType();
208   ArgTypes.append(FT->param_begin(), FT->param_end());
209   FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false);
210   return M.getOrInsertFunction(FnName, DXILOpFT);
211 }
212 
213 static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) {
214   auto DXILOpFn = createDXILOpFunction(DXILOp, F, M);
215   IRBuilder<> B(M.getContext());
216   Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
217   for (User *U : make_early_inc_range(F.users())) {
218     CallInst *CI = dyn_cast<CallInst>(U);
219     if (!CI)
220       continue;
221 
222     SmallVector<Value *> Args;
223     Args.emplace_back(DXILOpArg);
224     Args.append(CI->arg_begin(), CI->arg_end());
225     B.SetInsertPoint(CI);
226     CallInst *DXILCI = B.CreateCall(DXILOpFn, Args);
227     CI->replaceAllUsesWith(DXILCI);
228     CI->eraseFromParent();
229   }
230   if (F.user_empty())
231     F.eraseFromParent();
232 }
233 
234 static bool lowerIntrinsics(Module &M) {
235   bool Updated = false;
236   static SmallDenseMap<Intrinsic::ID, DXIL::OpCode> LowerMap = {
237       {Intrinsic::sin, DXIL::OpCode::Sin},
238       {Intrinsic::umax, DXIL::OpCode::UMax}};
239   for (Function &F : make_early_inc_range(M.functions())) {
240     if (!F.isDeclaration())
241       continue;
242     Intrinsic::ID ID = F.getIntrinsicID();
243     auto LowerIt = LowerMap.find(ID);
244     if (LowerIt == LowerMap.end())
245       continue;
246     lowerIntrinsic(LowerIt->second, F, M);
247     Updated = true;
248   }
249   return Updated;
250 }
251 
252 namespace {
253 /// A pass that transforms external global definitions into declarations.
254 class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
255 public:
256   PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
257     if (lowerIntrinsics(M))
258       return PreservedAnalyses::none();
259     return PreservedAnalyses::all();
260   }
261 };
262 } // namespace
263 
264 namespace {
265 class DXILOpLoweringLegacy : public ModulePass {
266 public:
267   bool runOnModule(Module &M) override { return lowerIntrinsics(M); }
268   StringRef getPassName() const override { return "DXIL Op Lowering"; }
269   DXILOpLoweringLegacy() : ModulePass(ID) {}
270 
271   static char ID; // Pass identification.
272 };
273 char DXILOpLoweringLegacy::ID = 0;
274 
275 } // end anonymous namespace
276 
277 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
278                       false, false)
279 INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
280                     false)
281 
282 ModulePass *llvm::createDXILOpLoweringLegacyPass() {
283   return new DXILOpLoweringLegacy();
284 }
285