1 //===- DXILOpLower.cpp - Lowering LLVM intrinsic to DIXLOp function -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains passes and utilities to lower llvm intrinsic call 10 /// to DXILOp function call. 11 //===----------------------------------------------------------------------===// 12 13 #include "DXILConstants.h" 14 #include "DirectX.h" 15 #include "llvm/ADT/SmallVector.h" 16 #include "llvm/CodeGen/Passes.h" 17 #include "llvm/IR/IRBuilder.h" 18 #include "llvm/IR/Instruction.h" 19 #include "llvm/IR/Intrinsics.h" 20 #include "llvm/IR/Module.h" 21 #include "llvm/IR/PassManager.h" 22 #include "llvm/Pass.h" 23 #include "llvm/Support/ErrorHandling.h" 24 25 #define DEBUG_TYPE "dxil-op-lower" 26 27 using namespace llvm; 28 using namespace llvm::DXIL; 29 30 constexpr StringLiteral DXILOpNamePrefix = "dx.op."; 31 32 enum OverloadKind : uint16_t { 33 VOID = 1, 34 HALF = 1 << 1, 35 FLOAT = 1 << 2, 36 DOUBLE = 1 << 3, 37 I1 = 1 << 4, 38 I8 = 1 << 5, 39 I16 = 1 << 6, 40 I32 = 1 << 7, 41 I64 = 1 << 8, 42 UserDefineType = 1 << 9, 43 ObjectType = 1 << 10, 44 }; 45 46 static const char *getOverloadTypeName(OverloadKind Kind) { 47 switch (Kind) { 48 case OverloadKind::HALF: 49 return "f16"; 50 case OverloadKind::FLOAT: 51 return "f32"; 52 case OverloadKind::DOUBLE: 53 return "f64"; 54 case OverloadKind::I1: 55 return "i1"; 56 case OverloadKind::I8: 57 return "i8"; 58 case OverloadKind::I16: 59 return "i16"; 60 case OverloadKind::I32: 61 return "i32"; 62 case OverloadKind::I64: 63 return "i64"; 64 case OverloadKind::VOID: 65 case OverloadKind::ObjectType: 66 case OverloadKind::UserDefineType: 67 break; 68 } 69 llvm_unreachable("invalid overload type for name"); 70 return "void"; 71 } 72 73 static OverloadKind getOverloadKind(Type *Ty) { 74 Type::TypeID T = Ty->getTypeID(); 75 switch (T) { 76 case Type::VoidTyID: 77 return OverloadKind::VOID; 78 case Type::HalfTyID: 79 return OverloadKind::HALF; 80 case Type::FloatTyID: 81 return OverloadKind::FLOAT; 82 case Type::DoubleTyID: 83 return OverloadKind::DOUBLE; 84 case Type::IntegerTyID: { 85 IntegerType *ITy = cast<IntegerType>(Ty); 86 unsigned Bits = ITy->getBitWidth(); 87 switch (Bits) { 88 case 1: 89 return OverloadKind::I1; 90 case 8: 91 return OverloadKind::I8; 92 case 16: 93 return OverloadKind::I16; 94 case 32: 95 return OverloadKind::I32; 96 case 64: 97 return OverloadKind::I64; 98 default: 99 llvm_unreachable("invalid overload type"); 100 return OverloadKind::VOID; 101 } 102 } 103 case Type::PointerTyID: 104 return OverloadKind::UserDefineType; 105 case Type::StructTyID: 106 return OverloadKind::ObjectType; 107 default: 108 llvm_unreachable("invalid overload type"); 109 return OverloadKind::VOID; 110 } 111 } 112 113 static std::string getTypeName(OverloadKind Kind, Type *Ty) { 114 if (Kind < OverloadKind::UserDefineType) { 115 return getOverloadTypeName(Kind); 116 } else if (Kind == OverloadKind::UserDefineType) { 117 StructType *ST = cast<StructType>(Ty); 118 return ST->getStructName().str(); 119 } else if (Kind == OverloadKind::ObjectType) { 120 StructType *ST = cast<StructType>(Ty); 121 return ST->getStructName().str(); 122 } else { 123 std::string Str; 124 raw_string_ostream OS(Str); 125 Ty->print(OS); 126 return OS.str(); 127 } 128 } 129 130 // Static properties. 131 struct OpCodeProperty { 132 DXIL::OpCode OpCode; 133 // Offset in DXILOpCodeNameTable. 134 unsigned OpCodeNameOffset; 135 DXIL::OpCodeClass OpCodeClass; 136 // Offset in DXILOpCodeClassNameTable. 137 unsigned OpCodeClassNameOffset; 138 uint16_t OverloadTys; 139 llvm::Attribute::AttrKind FuncAttr; 140 }; 141 142 // Include getOpCodeClassName getOpCodeProperty and getOpCodeName which 143 // generated by tableGen. 144 #define DXIL_OP_OPERATION_TABLE 145 #include "DXILOperation.inc" 146 #undef DXIL_OP_OPERATION_TABLE 147 148 static std::string constructOverloadName(OverloadKind Kind, Type *Ty, 149 const OpCodeProperty &Prop) { 150 if (Kind == OverloadKind::VOID) { 151 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); 152 } 153 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + 154 getTypeName(Kind, Ty)) 155 .str(); 156 } 157 158 static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F, 159 Module &M) { 160 const OpCodeProperty *Prop = getOpCodeProperty(DXILOp); 161 162 // Get return type as overload type for DXILOp. 163 // Only simple mapping case here, so return type is good enough. 164 Type *OverloadTy = F.getReturnType(); 165 166 OverloadKind Kind = getOverloadKind(OverloadTy); 167 // FIXME: find the issue and report error in clang instead of check it in 168 // backend. 169 if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { 170 llvm_unreachable("invalid overload"); 171 } 172 173 std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop); 174 assert(!M.getFunction(FnName) && "Function already exists"); 175 176 auto &Ctx = M.getContext(); 177 Type *OpCodeTy = Type::getInt32Ty(Ctx); 178 179 SmallVector<Type *> ArgTypes; 180 // DXIL has i32 opcode as first arg. 181 ArgTypes.emplace_back(OpCodeTy); 182 FunctionType *FT = F.getFunctionType(); 183 ArgTypes.append(FT->param_begin(), FT->param_end()); 184 FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false); 185 return M.getOrInsertFunction(FnName, DXILOpFT); 186 } 187 188 static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) { 189 auto DXILOpFn = createDXILOpFunction(DXILOp, F, M); 190 IRBuilder<> B(M.getContext()); 191 Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp)); 192 for (User *U : make_early_inc_range(F.users())) { 193 CallInst *CI = dyn_cast<CallInst>(U); 194 if (!CI) 195 continue; 196 197 SmallVector<Value *> Args; 198 Args.emplace_back(DXILOpArg); 199 Args.append(CI->arg_begin(), CI->arg_end()); 200 B.SetInsertPoint(CI); 201 CallInst *DXILCI = B.CreateCall(DXILOpFn, Args); 202 LLVM_DEBUG(DXILCI->setName(getOpCodeName(DXILOp))); 203 CI->replaceAllUsesWith(DXILCI); 204 CI->eraseFromParent(); 205 } 206 if (F.user_empty()) 207 F.eraseFromParent(); 208 } 209 210 static bool lowerIntrinsics(Module &M) { 211 bool Updated = false; 212 213 #define DXIL_OP_INTRINSIC_MAP 214 #include "DXILOperation.inc" 215 #undef DXIL_OP_INTRINSIC_MAP 216 217 for (Function &F : make_early_inc_range(M.functions())) { 218 if (!F.isDeclaration()) 219 continue; 220 Intrinsic::ID ID = F.getIntrinsicID(); 221 if (ID == Intrinsic::not_intrinsic) 222 continue; 223 auto LowerIt = LowerMap.find(ID); 224 if (LowerIt == LowerMap.end()) 225 continue; 226 lowerIntrinsic(LowerIt->second, F, M); 227 Updated = true; 228 } 229 return Updated; 230 } 231 232 namespace { 233 /// A pass that transforms external global definitions into declarations. 234 class DXILOpLowering : public PassInfoMixin<DXILOpLowering> { 235 public: 236 PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { 237 if (lowerIntrinsics(M)) 238 return PreservedAnalyses::none(); 239 return PreservedAnalyses::all(); 240 } 241 }; 242 } // namespace 243 244 namespace { 245 class DXILOpLoweringLegacy : public ModulePass { 246 public: 247 bool runOnModule(Module &M) override { return lowerIntrinsics(M); } 248 StringRef getPassName() const override { return "DXIL Op Lowering"; } 249 DXILOpLoweringLegacy() : ModulePass(ID) {} 250 251 static char ID; // Pass identification. 252 }; 253 char DXILOpLoweringLegacy::ID = 0; 254 255 } // end anonymous namespace 256 257 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", 258 false, false) 259 INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, 260 false) 261 262 ModulePass *llvm::createDXILOpLoweringLegacyPass() { 263 return new DXILOpLoweringLegacy(); 264 } 265