1 //===- DXILOpLower.cpp - Lowering LLVM intrinsic to DIXLOp function -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains passes and utilities to lower llvm intrinsic call 10 /// to DXILOp function call. 11 //===----------------------------------------------------------------------===// 12 13 #include "DXILConstants.h" 14 #include "DirectX.h" 15 #include "llvm/ADT/SmallVector.h" 16 #include "llvm/CodeGen/Passes.h" 17 #include "llvm/IR/IRBuilder.h" 18 #include "llvm/IR/Instruction.h" 19 #include "llvm/IR/Intrinsics.h" 20 #include "llvm/IR/Module.h" 21 #include "llvm/IR/PassManager.h" 22 #include "llvm/Pass.h" 23 #include "llvm/Support/ErrorHandling.h" 24 25 #define DEBUG_TYPE "dxil-op-lower" 26 27 using namespace llvm; 28 using namespace llvm::DXIL; 29 30 constexpr StringLiteral DXILOpNamePrefix = "dx.op."; 31 32 enum OverloadKind : uint16_t { 33 VOID = 1, 34 HALF = 1 << 1, 35 FLOAT = 1 << 2, 36 DOUBLE = 1 << 3, 37 I1 = 1 << 4, 38 I8 = 1 << 5, 39 I16 = 1 << 6, 40 I32 = 1 << 7, 41 I64 = 1 << 8, 42 UserDefineType = 1 << 9, 43 ObjectType = 1 << 10, 44 }; 45 46 static const char *getOverloadTypeName(OverloadKind Kind) { 47 switch (Kind) { 48 case OverloadKind::HALF: 49 return "f16"; 50 case OverloadKind::FLOAT: 51 return "f32"; 52 case OverloadKind::DOUBLE: 53 return "f64"; 54 case OverloadKind::I1: 55 return "i1"; 56 case OverloadKind::I8: 57 return "i8"; 58 case OverloadKind::I16: 59 return "i16"; 60 case OverloadKind::I32: 61 return "i32"; 62 case OverloadKind::I64: 63 return "i64"; 64 case OverloadKind::VOID: 65 case OverloadKind::ObjectType: 66 case OverloadKind::UserDefineType: 67 llvm_unreachable("invalid overload type for name"); 68 break; 69 } 70 } 71 72 static OverloadKind getOverloadKind(Type *Ty) { 73 Type::TypeID T = Ty->getTypeID(); 74 switch (T) { 75 case Type::VoidTyID: 76 return OverloadKind::VOID; 77 case Type::HalfTyID: 78 return OverloadKind::HALF; 79 case Type::FloatTyID: 80 return OverloadKind::FLOAT; 81 case Type::DoubleTyID: 82 return OverloadKind::DOUBLE; 83 case Type::IntegerTyID: { 84 IntegerType *ITy = cast<IntegerType>(Ty); 85 unsigned Bits = ITy->getBitWidth(); 86 switch (Bits) { 87 case 1: 88 return OverloadKind::I1; 89 case 8: 90 return OverloadKind::I8; 91 case 16: 92 return OverloadKind::I16; 93 case 32: 94 return OverloadKind::I32; 95 case 64: 96 return OverloadKind::I64; 97 default: 98 llvm_unreachable("invalid overload type"); 99 return OverloadKind::VOID; 100 } 101 } 102 case Type::PointerTyID: 103 return OverloadKind::UserDefineType; 104 case Type::StructTyID: 105 return OverloadKind::ObjectType; 106 default: 107 llvm_unreachable("invalid overload type"); 108 return OverloadKind::VOID; 109 } 110 } 111 112 static std::string getTypeName(OverloadKind Kind, Type *Ty) { 113 if (Kind < OverloadKind::UserDefineType) { 114 return getOverloadTypeName(Kind); 115 } else if (Kind == OverloadKind::UserDefineType) { 116 StructType *ST = cast<StructType>(Ty); 117 return ST->getStructName().str(); 118 } else if (Kind == OverloadKind::ObjectType) { 119 StructType *ST = cast<StructType>(Ty); 120 return ST->getStructName().str(); 121 } else { 122 std::string Str; 123 raw_string_ostream OS(Str); 124 Ty->print(OS); 125 return OS.str(); 126 } 127 } 128 129 // Static properties. 130 struct OpCodeProperty { 131 DXIL::OpCode OpCode; 132 // FIXME: change OpCodeName into index to a large string constant when move to 133 // tableGen. 134 const char *OpCodeName; 135 DXIL::OpCodeClass OpCodeClass; 136 uint16_t OverloadTys; 137 llvm::Attribute::AttrKind FuncAttr; 138 }; 139 140 static const char *getOpCodeClassName(const OpCodeProperty &Prop) { 141 // FIXME: generate this table with tableGen. 142 static const char *OpCodeClassNames[] = { 143 "unary", 144 }; 145 unsigned Index = static_cast<unsigned>(Prop.OpCodeClass); 146 assert(Index < (sizeof(OpCodeClassNames) / sizeof(OpCodeClassNames[0])) && 147 "Out of bound OpCodeClass"); 148 return OpCodeClassNames[Index]; 149 } 150 151 static std::string constructOverloadName(OverloadKind Kind, Type *Ty, 152 const OpCodeProperty &Prop) { 153 if (Kind == OverloadKind::VOID) { 154 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); 155 } 156 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + 157 getTypeName(Kind, Ty)) 158 .str(); 159 } 160 161 static const OpCodeProperty *getOpCodeProperty(DXIL::OpCode DXILOp) { 162 // FIXME: generate this table with tableGen. 163 static const OpCodeProperty OpCodeProps[] = { 164 {DXIL::OpCode::Sin, "Sin", OpCodeClass::Unary, 165 OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone}, 166 }; 167 // FIXME: change search to indexing with 168 // DXILOp once all DXIL op is added. 169 OpCodeProperty TmpProp; 170 TmpProp.OpCode = DXILOp; 171 const OpCodeProperty *Prop = 172 llvm::lower_bound(OpCodeProps, TmpProp, 173 [](const OpCodeProperty &A, const OpCodeProperty &B) { 174 return A.OpCode < B.OpCode; 175 }); 176 return Prop; 177 } 178 179 static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F, 180 Module &M) { 181 const OpCodeProperty *Prop = getOpCodeProperty(DXILOp); 182 183 // Get return type as overload type for DXILOp. 184 // Only simple mapping case here, so return type is good enough. 185 Type *OverloadTy = F.getReturnType(); 186 187 OverloadKind Kind = getOverloadKind(OverloadTy); 188 // FIXME: find the issue and report error in clang instead of check it in 189 // backend. 190 if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { 191 llvm_unreachable("invalid overload"); 192 } 193 194 std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop); 195 assert(!M.getFunction(FnName) && "Function already exists"); 196 197 auto &Ctx = M.getContext(); 198 Type *OpCodeTy = Type::getInt32Ty(Ctx); 199 200 SmallVector<Type *> ArgTypes; 201 // DXIL has i32 opcode as first arg. 202 ArgTypes.emplace_back(OpCodeTy); 203 FunctionType *FT = F.getFunctionType(); 204 ArgTypes.append(FT->param_begin(), FT->param_end()); 205 FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false); 206 return M.getOrInsertFunction(FnName, DXILOpFT); 207 } 208 209 static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) { 210 auto DXILOpFn = createDXILOpFunction(DXILOp, F, M); 211 IRBuilder<> B(M.getContext()); 212 Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp)); 213 for (User *U : make_early_inc_range(F.users())) { 214 CallInst *CI = dyn_cast<CallInst>(U); 215 if (!CI) 216 continue; 217 218 SmallVector<Value *> Args; 219 Args.emplace_back(DXILOpArg); 220 Args.append(CI->arg_begin(), CI->arg_end()); 221 B.SetInsertPoint(CI); 222 CallInst *DXILCI = B.CreateCall(DXILOpFn, Args); 223 CI->replaceAllUsesWith(DXILCI); 224 CI->eraseFromParent(); 225 } 226 if (F.user_empty()) 227 F.eraseFromParent(); 228 } 229 230 static bool lowerIntrinsics(Module &M) { 231 bool Updated = false; 232 static SmallDenseMap<Intrinsic::ID, DXIL::OpCode> LowerMap = { 233 {Intrinsic::sin, DXIL::OpCode::Sin}}; 234 for (Function &F : make_early_inc_range(M.functions())) { 235 if (!F.isDeclaration()) 236 continue; 237 Intrinsic::ID ID = F.getIntrinsicID(); 238 auto LowerIt = LowerMap.find(ID); 239 if (LowerIt == LowerMap.end()) 240 continue; 241 lowerIntrinsic(LowerIt->second, F, M); 242 Updated = true; 243 } 244 return Updated; 245 } 246 247 namespace { 248 /// A pass that transforms external global definitions into declarations. 249 class DXILOpLowering : public PassInfoMixin<DXILOpLowering> { 250 public: 251 PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { 252 if (lowerIntrinsics(M)) 253 return PreservedAnalyses::none(); 254 return PreservedAnalyses::all(); 255 } 256 }; 257 } // namespace 258 259 namespace { 260 class DXILOpLoweringLegacy : public ModulePass { 261 public: 262 bool runOnModule(Module &M) override { return lowerIntrinsics(M); } 263 StringRef getPassName() const override { return "DXIL Op Lowering"; } 264 DXILOpLoweringLegacy() : ModulePass(ID) {} 265 266 static char ID; // Pass identification. 267 }; 268 char DXILOpLoweringLegacy::ID = 0; 269 270 } // end anonymous namespace 271 272 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", 273 false, false) 274 INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, 275 false) 276 277 ModulePass *llvm::createDXILOpLoweringLegacyPass() { 278 return new DXILOpLoweringLegacy(); 279 } 280