1 //===- DXILOpLower.cpp - Lowering LLVM intrinsic to DIXLOp function -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains passes and utilities to lower llvm intrinsic call 10 /// to DXILOp function call. 11 //===----------------------------------------------------------------------===// 12 13 #include "DXILConstants.h" 14 #include "DirectX.h" 15 #include "llvm/ADT/SmallVector.h" 16 #include "llvm/CodeGen/Passes.h" 17 #include "llvm/IR/IRBuilder.h" 18 #include "llvm/IR/Instruction.h" 19 #include "llvm/IR/Intrinsics.h" 20 #include "llvm/IR/Module.h" 21 #include "llvm/IR/PassManager.h" 22 #include "llvm/Pass.h" 23 #include "llvm/Support/ErrorHandling.h" 24 25 #define DEBUG_TYPE "dxil-op-lower" 26 27 using namespace llvm; 28 using namespace llvm::DXIL; 29 30 constexpr StringLiteral DXILOpNamePrefix = "dx.op."; 31 32 enum OverloadKind : uint16_t { 33 VOID = 1, 34 HALF = 1 << 1, 35 FLOAT = 1 << 2, 36 DOUBLE = 1 << 3, 37 I1 = 1 << 4, 38 I8 = 1 << 5, 39 I16 = 1 << 6, 40 I32 = 1 << 7, 41 I64 = 1 << 8, 42 UserDefineType = 1 << 9, 43 ObjectType = 1 << 10, 44 }; 45 46 static const char *getOverloadTypeName(OverloadKind Kind) { 47 switch (Kind) { 48 case OverloadKind::HALF: 49 return "f16"; 50 case OverloadKind::FLOAT: 51 return "f32"; 52 case OverloadKind::DOUBLE: 53 return "f64"; 54 case OverloadKind::I1: 55 return "i1"; 56 case OverloadKind::I8: 57 return "i8"; 58 case OverloadKind::I16: 59 return "i16"; 60 case OverloadKind::I32: 61 return "i32"; 62 case OverloadKind::I64: 63 return "i64"; 64 case OverloadKind::VOID: 65 case OverloadKind::ObjectType: 66 case OverloadKind::UserDefineType: 67 break; 68 } 69 llvm_unreachable("invalid overload type for name"); 70 return "void"; 71 } 72 73 static OverloadKind getOverloadKind(Type *Ty) { 74 Type::TypeID T = Ty->getTypeID(); 75 switch (T) { 76 case Type::VoidTyID: 77 return OverloadKind::VOID; 78 case Type::HalfTyID: 79 return OverloadKind::HALF; 80 case Type::FloatTyID: 81 return OverloadKind::FLOAT; 82 case Type::DoubleTyID: 83 return OverloadKind::DOUBLE; 84 case Type::IntegerTyID: { 85 IntegerType *ITy = cast<IntegerType>(Ty); 86 unsigned Bits = ITy->getBitWidth(); 87 switch (Bits) { 88 case 1: 89 return OverloadKind::I1; 90 case 8: 91 return OverloadKind::I8; 92 case 16: 93 return OverloadKind::I16; 94 case 32: 95 return OverloadKind::I32; 96 case 64: 97 return OverloadKind::I64; 98 default: 99 llvm_unreachable("invalid overload type"); 100 return OverloadKind::VOID; 101 } 102 } 103 case Type::PointerTyID: 104 return OverloadKind::UserDefineType; 105 case Type::StructTyID: 106 return OverloadKind::ObjectType; 107 default: 108 llvm_unreachable("invalid overload type"); 109 return OverloadKind::VOID; 110 } 111 } 112 113 static std::string getTypeName(OverloadKind Kind, Type *Ty) { 114 if (Kind < OverloadKind::UserDefineType) { 115 return getOverloadTypeName(Kind); 116 } else if (Kind == OverloadKind::UserDefineType) { 117 StructType *ST = cast<StructType>(Ty); 118 return ST->getStructName().str(); 119 } else if (Kind == OverloadKind::ObjectType) { 120 StructType *ST = cast<StructType>(Ty); 121 return ST->getStructName().str(); 122 } else { 123 std::string Str; 124 raw_string_ostream OS(Str); 125 Ty->print(OS); 126 return OS.str(); 127 } 128 } 129 130 // Static properties. 131 struct OpCodeProperty { 132 DXIL::OpCode OpCode; 133 // FIXME: change OpCodeName into index to a large string constant when move to 134 // tableGen. 135 const char *OpCodeName; 136 DXIL::OpCodeClass OpCodeClass; 137 uint16_t OverloadTys; 138 llvm::Attribute::AttrKind FuncAttr; 139 }; 140 141 static const char *getOpCodeClassName(const OpCodeProperty &Prop) { 142 // FIXME: generate this table with tableGen. 143 static const char *OpCodeClassNames[] = { 144 "binary", 145 "unary", 146 }; 147 unsigned Index = static_cast<unsigned>(Prop.OpCodeClass); 148 assert(Index < (sizeof(OpCodeClassNames) / sizeof(OpCodeClassNames[0])) && 149 "Out of bound OpCodeClass"); 150 return OpCodeClassNames[Index]; 151 } 152 153 static std::string constructOverloadName(OverloadKind Kind, Type *Ty, 154 const OpCodeProperty &Prop) { 155 if (Kind == OverloadKind::VOID) { 156 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); 157 } 158 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + 159 getTypeName(Kind, Ty)) 160 .str(); 161 } 162 163 static const OpCodeProperty *getOpCodeProperty(DXIL::OpCode DXILOp) { 164 // FIXME: generate this table with tableGen. 165 static const OpCodeProperty OpCodeProps[] = { 166 {DXIL::OpCode::Sin, "Sin", OpCodeClass::Unary, 167 OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone}, 168 {DXIL::OpCode::UMax, "UMax", OpCodeClass::Binary, 169 OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64, 170 Attribute::AttrKind::ReadNone}, 171 }; 172 // FIXME: change search to indexing with 173 // DXILOp once all DXIL op is added. 174 OpCodeProperty TmpProp; 175 TmpProp.OpCode = DXILOp; 176 const OpCodeProperty *Prop = 177 llvm::lower_bound(OpCodeProps, TmpProp, 178 [](const OpCodeProperty &A, const OpCodeProperty &B) { 179 return A.OpCode < B.OpCode; 180 }); 181 return Prop; 182 } 183 184 static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F, 185 Module &M) { 186 const OpCodeProperty *Prop = getOpCodeProperty(DXILOp); 187 188 // Get return type as overload type for DXILOp. 189 // Only simple mapping case here, so return type is good enough. 190 Type *OverloadTy = F.getReturnType(); 191 192 OverloadKind Kind = getOverloadKind(OverloadTy); 193 // FIXME: find the issue and report error in clang instead of check it in 194 // backend. 195 if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { 196 llvm_unreachable("invalid overload"); 197 } 198 199 std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop); 200 assert(!M.getFunction(FnName) && "Function already exists"); 201 202 auto &Ctx = M.getContext(); 203 Type *OpCodeTy = Type::getInt32Ty(Ctx); 204 205 SmallVector<Type *> ArgTypes; 206 // DXIL has i32 opcode as first arg. 207 ArgTypes.emplace_back(OpCodeTy); 208 FunctionType *FT = F.getFunctionType(); 209 ArgTypes.append(FT->param_begin(), FT->param_end()); 210 FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false); 211 return M.getOrInsertFunction(FnName, DXILOpFT); 212 } 213 214 static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) { 215 auto DXILOpFn = createDXILOpFunction(DXILOp, F, M); 216 IRBuilder<> B(M.getContext()); 217 Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp)); 218 for (User *U : make_early_inc_range(F.users())) { 219 CallInst *CI = dyn_cast<CallInst>(U); 220 if (!CI) 221 continue; 222 223 SmallVector<Value *> Args; 224 Args.emplace_back(DXILOpArg); 225 Args.append(CI->arg_begin(), CI->arg_end()); 226 B.SetInsertPoint(CI); 227 CallInst *DXILCI = B.CreateCall(DXILOpFn, Args); 228 CI->replaceAllUsesWith(DXILCI); 229 CI->eraseFromParent(); 230 } 231 if (F.user_empty()) 232 F.eraseFromParent(); 233 } 234 235 static bool lowerIntrinsics(Module &M) { 236 bool Updated = false; 237 238 #define DXIL_OP_INTRINSIC_MAP 239 #include "DXILOperation.inc" 240 #undef DXIL_OP_INTRINSIC_MAP 241 242 for (Function &F : make_early_inc_range(M.functions())) { 243 if (!F.isDeclaration()) 244 continue; 245 Intrinsic::ID ID = F.getIntrinsicID(); 246 if (ID == Intrinsic::not_intrinsic) 247 continue; 248 auto LowerIt = LowerMap.find(ID); 249 if (LowerIt == LowerMap.end()) 250 continue; 251 lowerIntrinsic(LowerIt->second, F, M); 252 Updated = true; 253 } 254 return Updated; 255 } 256 257 namespace { 258 /// A pass that transforms external global definitions into declarations. 259 class DXILOpLowering : public PassInfoMixin<DXILOpLowering> { 260 public: 261 PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { 262 if (lowerIntrinsics(M)) 263 return PreservedAnalyses::none(); 264 return PreservedAnalyses::all(); 265 } 266 }; 267 } // namespace 268 269 namespace { 270 class DXILOpLoweringLegacy : public ModulePass { 271 public: 272 bool runOnModule(Module &M) override { return lowerIntrinsics(M); } 273 StringRef getPassName() const override { return "DXIL Op Lowering"; } 274 DXILOpLoweringLegacy() : ModulePass(ID) {} 275 276 static char ID; // Pass identification. 277 }; 278 char DXILOpLoweringLegacy::ID = 0; 279 280 } // end anonymous namespace 281 282 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", 283 false, false) 284 INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, 285 false) 286 287 ModulePass *llvm::createDXILOpLoweringLegacyPass() { 288 return new DXILOpLoweringLegacy(); 289 } 290