1 //===- DXILOpLower.cpp - Lowering LLVM intrinsic to DIXLOp function -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains passes and utilities to lower llvm intrinsic call 10 /// to DXILOp function call. 11 //===----------------------------------------------------------------------===// 12 13 #include "DXILConstants.h" 14 #include "DirectX.h" 15 #include "llvm/ADT/SmallVector.h" 16 #include "llvm/CodeGen/Passes.h" 17 #include "llvm/IR/IRBuilder.h" 18 #include "llvm/IR/Instruction.h" 19 #include "llvm/IR/Intrinsics.h" 20 #include "llvm/IR/Module.h" 21 #include "llvm/IR/PassManager.h" 22 #include "llvm/Pass.h" 23 #include "llvm/Support/ErrorHandling.h" 24 25 #define DEBUG_TYPE "dxil-op-lower" 26 27 using namespace llvm; 28 using namespace llvm::DXIL; 29 30 constexpr StringLiteral DXILOpNamePrefix = "dx.op."; 31 32 enum OverloadKind : uint16_t { 33 VOID = 1, 34 HALF = 1 << 1, 35 FLOAT = 1 << 2, 36 DOUBLE = 1 << 3, 37 I1 = 1 << 4, 38 I8 = 1 << 5, 39 I16 = 1 << 6, 40 I32 = 1 << 7, 41 I64 = 1 << 8, 42 UserDefineType = 1 << 9, 43 ObjectType = 1 << 10, 44 }; 45 46 static const char *getOverloadTypeName(OverloadKind Kind) { 47 switch (Kind) { 48 case OverloadKind::HALF: 49 return "f16"; 50 case OverloadKind::FLOAT: 51 return "f32"; 52 case OverloadKind::DOUBLE: 53 return "f64"; 54 case OverloadKind::I1: 55 return "i1"; 56 case OverloadKind::I8: 57 return "i8"; 58 case OverloadKind::I16: 59 return "i16"; 60 case OverloadKind::I32: 61 return "i32"; 62 case OverloadKind::I64: 63 return "i64"; 64 case OverloadKind::VOID: 65 case OverloadKind::ObjectType: 66 case OverloadKind::UserDefineType: 67 llvm_unreachable("invalid overload type for name"); 68 return "void"; 69 } 70 } 71 72 static OverloadKind getOverloadKind(Type *Ty) { 73 Type::TypeID T = Ty->getTypeID(); 74 switch (T) { 75 case Type::VoidTyID: 76 return OverloadKind::VOID; 77 case Type::HalfTyID: 78 return OverloadKind::HALF; 79 case Type::FloatTyID: 80 return OverloadKind::FLOAT; 81 case Type::DoubleTyID: 82 return OverloadKind::DOUBLE; 83 case Type::IntegerTyID: { 84 IntegerType *ITy = cast<IntegerType>(Ty); 85 unsigned Bits = ITy->getBitWidth(); 86 switch (Bits) { 87 case 1: 88 return OverloadKind::I1; 89 case 8: 90 return OverloadKind::I8; 91 case 16: 92 return OverloadKind::I16; 93 case 32: 94 return OverloadKind::I32; 95 case 64: 96 return OverloadKind::I64; 97 default: 98 llvm_unreachable("invalid overload type"); 99 return OverloadKind::VOID; 100 } 101 } 102 case Type::PointerTyID: 103 return OverloadKind::UserDefineType; 104 case Type::StructTyID: 105 return OverloadKind::ObjectType; 106 default: 107 llvm_unreachable("invalid overload type"); 108 return OverloadKind::VOID; 109 } 110 } 111 112 static std::string getTypeName(OverloadKind Kind, Type *Ty) { 113 if (Kind < OverloadKind::UserDefineType) { 114 return getOverloadTypeName(Kind); 115 } else if (Kind == OverloadKind::UserDefineType) { 116 StructType *ST = cast<StructType>(Ty); 117 return ST->getStructName().str(); 118 } else if (Kind == OverloadKind::ObjectType) { 119 StructType *ST = cast<StructType>(Ty); 120 return ST->getStructName().str(); 121 } else { 122 std::string Str; 123 raw_string_ostream OS(Str); 124 Ty->print(OS); 125 return OS.str(); 126 } 127 } 128 129 // Static properties. 130 struct OpCodeProperty { 131 DXIL::OpCode OpCode; 132 // FIXME: change OpCodeName into index to a large string constant when move to 133 // tableGen. 134 const char *OpCodeName; 135 DXIL::OpCodeClass OpCodeClass; 136 uint16_t OverloadTys; 137 llvm::Attribute::AttrKind FuncAttr; 138 }; 139 140 static const char *getOpCodeClassName(const OpCodeProperty &Prop) { 141 // FIXME: generate this table with tableGen. 142 static const char *OpCodeClassNames[] = { 143 "binary", 144 "unary", 145 }; 146 unsigned Index = static_cast<unsigned>(Prop.OpCodeClass); 147 assert(Index < (sizeof(OpCodeClassNames) / sizeof(OpCodeClassNames[0])) && 148 "Out of bound OpCodeClass"); 149 return OpCodeClassNames[Index]; 150 } 151 152 static std::string constructOverloadName(OverloadKind Kind, Type *Ty, 153 const OpCodeProperty &Prop) { 154 if (Kind == OverloadKind::VOID) { 155 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); 156 } 157 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + 158 getTypeName(Kind, Ty)) 159 .str(); 160 } 161 162 static const OpCodeProperty *getOpCodeProperty(DXIL::OpCode DXILOp) { 163 // FIXME: generate this table with tableGen. 164 static const OpCodeProperty OpCodeProps[] = { 165 {DXIL::OpCode::Sin, "Sin", OpCodeClass::Unary, 166 OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone}, 167 {DXIL::OpCode::UMax, "UMax", OpCodeClass::Binary, 168 OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64, 169 Attribute::AttrKind::ReadNone}, 170 }; 171 // FIXME: change search to indexing with 172 // DXILOp once all DXIL op is added. 173 OpCodeProperty TmpProp; 174 TmpProp.OpCode = DXILOp; 175 const OpCodeProperty *Prop = 176 llvm::lower_bound(OpCodeProps, TmpProp, 177 [](const OpCodeProperty &A, const OpCodeProperty &B) { 178 return A.OpCode < B.OpCode; 179 }); 180 return Prop; 181 } 182 183 static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F, 184 Module &M) { 185 const OpCodeProperty *Prop = getOpCodeProperty(DXILOp); 186 187 // Get return type as overload type for DXILOp. 188 // Only simple mapping case here, so return type is good enough. 189 Type *OverloadTy = F.getReturnType(); 190 191 OverloadKind Kind = getOverloadKind(OverloadTy); 192 // FIXME: find the issue and report error in clang instead of check it in 193 // backend. 194 if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { 195 llvm_unreachable("invalid overload"); 196 } 197 198 std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop); 199 assert(!M.getFunction(FnName) && "Function already exists"); 200 201 auto &Ctx = M.getContext(); 202 Type *OpCodeTy = Type::getInt32Ty(Ctx); 203 204 SmallVector<Type *> ArgTypes; 205 // DXIL has i32 opcode as first arg. 206 ArgTypes.emplace_back(OpCodeTy); 207 FunctionType *FT = F.getFunctionType(); 208 ArgTypes.append(FT->param_begin(), FT->param_end()); 209 FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false); 210 return M.getOrInsertFunction(FnName, DXILOpFT); 211 } 212 213 static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) { 214 auto DXILOpFn = createDXILOpFunction(DXILOp, F, M); 215 IRBuilder<> B(M.getContext()); 216 Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp)); 217 for (User *U : make_early_inc_range(F.users())) { 218 CallInst *CI = dyn_cast<CallInst>(U); 219 if (!CI) 220 continue; 221 222 SmallVector<Value *> Args; 223 Args.emplace_back(DXILOpArg); 224 Args.append(CI->arg_begin(), CI->arg_end()); 225 B.SetInsertPoint(CI); 226 CallInst *DXILCI = B.CreateCall(DXILOpFn, Args); 227 CI->replaceAllUsesWith(DXILCI); 228 CI->eraseFromParent(); 229 } 230 if (F.user_empty()) 231 F.eraseFromParent(); 232 } 233 234 static bool lowerIntrinsics(Module &M) { 235 bool Updated = false; 236 static SmallDenseMap<Intrinsic::ID, DXIL::OpCode> LowerMap = { 237 {Intrinsic::sin, DXIL::OpCode::Sin}, 238 {Intrinsic::umax, DXIL::OpCode::UMax}}; 239 for (Function &F : make_early_inc_range(M.functions())) { 240 if (!F.isDeclaration()) 241 continue; 242 Intrinsic::ID ID = F.getIntrinsicID(); 243 auto LowerIt = LowerMap.find(ID); 244 if (LowerIt == LowerMap.end()) 245 continue; 246 lowerIntrinsic(LowerIt->second, F, M); 247 Updated = true; 248 } 249 return Updated; 250 } 251 252 namespace { 253 /// A pass that transforms external global definitions into declarations. 254 class DXILOpLowering : public PassInfoMixin<DXILOpLowering> { 255 public: 256 PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { 257 if (lowerIntrinsics(M)) 258 return PreservedAnalyses::none(); 259 return PreservedAnalyses::all(); 260 } 261 }; 262 } // namespace 263 264 namespace { 265 class DXILOpLoweringLegacy : public ModulePass { 266 public: 267 bool runOnModule(Module &M) override { return lowerIntrinsics(M); } 268 StringRef getPassName() const override { return "DXIL Op Lowering"; } 269 DXILOpLoweringLegacy() : ModulePass(ID) {} 270 271 static char ID; // Pass identification. 272 }; 273 char DXILOpLoweringLegacy::ID = 0; 274 275 } // end anonymous namespace 276 277 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", 278 false, false) 279 INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, 280 false) 281 282 ModulePass *llvm::createDXILOpLoweringLegacyPass() { 283 return new DXILOpLoweringLegacy(); 284 } 285