1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the implementation of the SPIRVGlobalRegistry class, 10 // which is used to maintain rich type information required for SPIR-V even 11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into 12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds 13 // and supports consistency of constants and global variables. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "SPIRVGlobalRegistry.h" 18 #include "SPIRV.h" 19 #include "SPIRVSubtarget.h" 20 #include "SPIRVTargetMachine.h" 21 #include "SPIRVUtils.h" 22 23 using namespace llvm; 24 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize) 25 : PointerSize(PointerSize) {} 26 27 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg( 28 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, 29 SPIRV::AccessQualifier AccessQual, bool EmitIR) { 30 31 SPIRVType *SpirvType = 32 getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); 33 assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder); 34 return SpirvType; 35 } 36 37 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType, 38 Register VReg, 39 MachineIRBuilder &MIRBuilder) { 40 VRegToTypeMap[&MIRBuilder.getMF()][VReg] = SpirvType; 41 } 42 43 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) { 44 auto &MRI = MIRBuilder.getMF().getRegInfo(); 45 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 46 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 47 return Res; 48 } 49 50 static Register createTypeVReg(MachineRegisterInfo &MRI) { 51 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 52 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 53 return Res; 54 } 55 56 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) { 57 return MIRBuilder.buildInstr(SPIRV::OpTypeBool) 58 .addDef(createTypeVReg(MIRBuilder)); 59 } 60 61 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width, 62 MachineIRBuilder &MIRBuilder, 63 bool IsSigned) { 64 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt) 65 .addDef(createTypeVReg(MIRBuilder)) 66 .addImm(Width) 67 .addImm(IsSigned ? 1 : 0); 68 return MIB; 69 } 70 71 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width, 72 MachineIRBuilder &MIRBuilder) { 73 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat) 74 .addDef(createTypeVReg(MIRBuilder)) 75 .addImm(Width); 76 return MIB; 77 } 78 79 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) { 80 return MIRBuilder.buildInstr(SPIRV::OpTypeVoid) 81 .addDef(createTypeVReg(MIRBuilder)); 82 } 83 84 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, 85 SPIRVType *ElemType, 86 MachineIRBuilder &MIRBuilder) { 87 auto EleOpc = ElemType->getOpcode(); 88 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || 89 EleOpc == SPIRV::OpTypeBool) && 90 "Invalid vector element type"); 91 92 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector) 93 .addDef(createTypeVReg(MIRBuilder)) 94 .addUse(getSPIRVTypeID(ElemType)) 95 .addImm(NumElems); 96 return MIB; 97 } 98 99 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val, 100 MachineIRBuilder &MIRBuilder, 101 SPIRVType *SpvType, 102 bool EmitIR) { 103 auto &MF = MIRBuilder.getMF(); 104 Register Res; 105 const IntegerType *LLVMIntTy; 106 if (SpvType) 107 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType)); 108 else 109 LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext()); 110 // Find a constant in DT or build a new one. 111 const auto ConstInt = 112 ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 113 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 114 Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); 115 assignTypeToVReg(LLVMIntTy, Res, MIRBuilder); 116 if (EmitIR) 117 MIRBuilder.buildConstant(Res, *ConstInt); 118 else 119 MIRBuilder.buildInstr(SPIRV::OpConstantI) 120 .addDef(Res) 121 .addImm(ConstInt->getSExtValue()); 122 return Res; 123 } 124 125 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val, 126 MachineIRBuilder &MIRBuilder, 127 SPIRVType *SpvType) { 128 auto &MF = MIRBuilder.getMF(); 129 Register Res; 130 const Type *LLVMFPTy; 131 if (SpvType) { 132 LLVMFPTy = getTypeForSPIRVType(SpvType); 133 assert(LLVMFPTy->isFloatingPointTy()); 134 } else { 135 LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext()); 136 } 137 // Find a constant in DT or build a new one. 138 const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val); 139 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 140 Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); 141 assignTypeToVReg(LLVMFPTy, Res, MIRBuilder); 142 MIRBuilder.buildFConstant(Res, *ConstFP); 143 return Res; 144 } 145 146 Register SPIRVGlobalRegistry::buildGlobalVariable( 147 Register ResVReg, SPIRVType *BaseType, StringRef Name, 148 const GlobalValue *GV, SPIRV::StorageClass Storage, 149 const MachineInstr *Init, bool IsConst, bool HasLinkageTy, 150 SPIRV::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, 151 bool IsInstSelector) { 152 const GlobalVariable *GVar = nullptr; 153 if (GV) 154 GVar = cast<const GlobalVariable>(GV); 155 else { 156 // If GV is not passed explicitly, use the name to find or construct 157 // the global variable. 158 Module *M = MIRBuilder.getMF().getFunction().getParent(); 159 GVar = M->getGlobalVariable(Name); 160 if (GVar == nullptr) { 161 const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type. 162 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false, 163 GlobalValue::ExternalLinkage, nullptr, 164 Twine(Name)); 165 } 166 GV = GVar; 167 } 168 Register Reg; 169 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable) 170 .addDef(ResVReg) 171 .addUse(getSPIRVTypeID(BaseType)) 172 .addImm(static_cast<uint32_t>(Storage)); 173 174 if (Init != 0) { 175 MIB.addUse(Init->getOperand(0).getReg()); 176 } 177 178 // ISel may introduce a new register on this step, so we need to add it to 179 // DT and correct its type avoiding fails on the next stage. 180 if (IsInstSelector) { 181 const auto &Subtarget = CurMF->getSubtarget(); 182 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 183 *Subtarget.getRegisterInfo(), 184 *Subtarget.getRegBankInfo()); 185 } 186 Reg = MIB->getOperand(0).getReg(); 187 188 // Set to Reg the same type as ResVReg has. 189 auto MRI = MIRBuilder.getMRI(); 190 assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected"); 191 if (Reg != ResVReg) { 192 LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32); 193 MRI->setType(Reg, RegLLTy); 194 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder); 195 } 196 197 // If it's a global variable with name, output OpName for it. 198 if (GVar && GVar->hasName()) 199 buildOpName(Reg, GVar->getName(), MIRBuilder); 200 201 // Output decorations for the GV. 202 // TODO: maybe move to GenerateDecorations pass. 203 if (IsConst) 204 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {}); 205 206 if (GVar && GVar->getAlign().valueOrOne().value() != 1) 207 buildOpDecorate( 208 Reg, MIRBuilder, SPIRV::Decoration::Alignment, 209 {static_cast<uint32_t>(GVar->getAlign().valueOrOne().value())}); 210 211 if (HasLinkageTy) 212 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 213 {static_cast<uint32_t>(LinkageType)}, Name); 214 return Reg; 215 } 216 217 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems, 218 SPIRVType *ElemType, 219 MachineIRBuilder &MIRBuilder, 220 bool EmitIR) { 221 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) && 222 "Invalid array element type"); 223 Register NumElementsVReg = 224 buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR); 225 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray) 226 .addDef(createTypeVReg(MIRBuilder)) 227 .addUse(getSPIRVTypeID(ElemType)) 228 .addUse(NumElementsVReg); 229 return MIB; 230 } 231 232 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(SPIRV::StorageClass SC, 233 SPIRVType *ElemType, 234 MachineIRBuilder &MIRBuilder) { 235 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypePointer) 236 .addDef(createTypeVReg(MIRBuilder)) 237 .addImm(static_cast<uint32_t>(SC)) 238 .addUse(getSPIRVTypeID(ElemType)); 239 return MIB; 240 } 241 242 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction( 243 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes, 244 MachineIRBuilder &MIRBuilder) { 245 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction) 246 .addDef(createTypeVReg(MIRBuilder)) 247 .addUse(getSPIRVTypeID(RetType)); 248 for (const SPIRVType *ArgType : ArgTypes) 249 MIB.addUse(getSPIRVTypeID(ArgType)); 250 return MIB; 251 } 252 253 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty, 254 MachineIRBuilder &MIRBuilder, 255 SPIRV::AccessQualifier AccQual, 256 bool EmitIR) { 257 if (auto IType = dyn_cast<IntegerType>(Ty)) { 258 const unsigned Width = IType->getBitWidth(); 259 return Width == 1 ? getOpTypeBool(MIRBuilder) 260 : getOpTypeInt(Width, MIRBuilder, false); 261 } 262 if (Ty->isFloatingPointTy()) 263 return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); 264 if (Ty->isVoidTy()) 265 return getOpTypeVoid(MIRBuilder); 266 if (Ty->isVectorTy()) { 267 auto El = getOrCreateSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), 268 MIRBuilder); 269 return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El, 270 MIRBuilder); 271 } 272 if (Ty->isArrayTy()) { 273 auto *El = getOrCreateSPIRVType(Ty->getArrayElementType(), MIRBuilder); 274 return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR); 275 } 276 assert(!isa<StructType>(Ty) && "Unsupported StructType"); 277 if (auto FType = dyn_cast<FunctionType>(Ty)) { 278 SPIRVType *RetTy = getOrCreateSPIRVType(FType->getReturnType(), MIRBuilder); 279 SmallVector<SPIRVType *, 4> ParamTypes; 280 for (const auto &t : FType->params()) { 281 ParamTypes.push_back(getOrCreateSPIRVType(t, MIRBuilder)); 282 } 283 return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); 284 } 285 if (auto PType = dyn_cast<PointerType>(Ty)) { 286 Type *ElemType = PType->getPointerElementType(); 287 288 // Some OpenCL and SPIRV builtins like image2d_t are passed in as pointers, 289 // but should be treated as custom types like OpTypeImage. 290 assert(!isa<StructType>(ElemType) && "Unsupported StructType pointer"); 291 292 // Otherwise, treat it as a regular pointer type. 293 auto SC = addressSpaceToStorageClass(PType->getAddressSpace()); 294 SPIRVType *SpvElementType = getOrCreateSPIRVType( 295 ElemType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR); 296 return getOpTypePointer(SC, SpvElementType, MIRBuilder); 297 } 298 llvm_unreachable("Unable to convert LLVM type to SPIRVType"); 299 } 300 301 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const { 302 auto t = VRegToTypeMap.find(CurMF); 303 if (t != VRegToTypeMap.end()) { 304 auto tt = t->second.find(VReg); 305 if (tt != t->second.end()) 306 return tt->second; 307 } 308 return nullptr; 309 } 310 311 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( 312 const Type *Type, MachineIRBuilder &MIRBuilder, 313 SPIRV::AccessQualifier AccessQual, bool EmitIR) { 314 SPIRVType *SpirvType = createSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); 315 VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; 316 SPIRVToLLVMType[SpirvType] = Type; 317 return SpirvType; 318 } 319 320 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg, 321 unsigned TypeOpcode) const { 322 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 323 assert(Type && "isScalarOfType VReg has no type assigned"); 324 return Type->getOpcode() == TypeOpcode; 325 } 326 327 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg, 328 unsigned TypeOpcode) const { 329 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 330 assert(Type && "isScalarOrVectorOfType VReg has no type assigned"); 331 if (Type->getOpcode() == TypeOpcode) 332 return true; 333 if (Type->getOpcode() == SPIRV::OpTypeVector) { 334 Register ScalarTypeVReg = Type->getOperand(1).getReg(); 335 SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg); 336 return ScalarType->getOpcode() == TypeOpcode; 337 } 338 return false; 339 } 340 341 unsigned 342 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const { 343 assert(Type && "Invalid Type pointer"); 344 if (Type->getOpcode() == SPIRV::OpTypeVector) { 345 auto EleTypeReg = Type->getOperand(1).getReg(); 346 Type = getSPIRVTypeForVReg(EleTypeReg); 347 } 348 if (Type->getOpcode() == SPIRV::OpTypeInt || 349 Type->getOpcode() == SPIRV::OpTypeFloat) 350 return Type->getOperand(1).getImm(); 351 if (Type->getOpcode() == SPIRV::OpTypeBool) 352 return 1; 353 llvm_unreachable("Attempting to get bit width of non-integer/float type."); 354 } 355 356 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const { 357 assert(Type && "Invalid Type pointer"); 358 if (Type->getOpcode() == SPIRV::OpTypeVector) { 359 auto EleTypeReg = Type->getOperand(1).getReg(); 360 Type = getSPIRVTypeForVReg(EleTypeReg); 361 } 362 if (Type->getOpcode() == SPIRV::OpTypeInt) 363 return Type->getOperand(2).getImm() != 0; 364 llvm_unreachable("Attempting to get sign of non-integer type."); 365 } 366 367 SPIRV::StorageClass 368 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const { 369 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 370 assert(Type && Type->getOpcode() == SPIRV::OpTypePointer && 371 Type->getOperand(1).isImm() && "Pointer type is expected"); 372 return static_cast<SPIRV::StorageClass>(Type->getOperand(1).getImm()); 373 } 374 375 SPIRVType * 376 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth, 377 MachineIRBuilder &MIRBuilder) { 378 return getOrCreateSPIRVType( 379 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth), 380 MIRBuilder); 381 } 382 383 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(Type *LLVMTy, 384 MachineInstrBuilder MIB) { 385 SPIRVType *SpirvType = MIB; 386 VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; 387 SPIRVToLLVMType[SpirvType] = LLVMTy; 388 return SpirvType; 389 } 390 391 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType( 392 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) { 393 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth); 394 MachineBasicBlock &BB = *I.getParent(); 395 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt)) 396 .addDef(createTypeVReg(CurMF->getRegInfo())) 397 .addImm(BitWidth) 398 .addImm(0); 399 return restOfCreateSPIRVType(LLVMTy, MIB); 400 } 401 402 SPIRVType * 403 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) { 404 return getOrCreateSPIRVType( 405 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1), 406 MIRBuilder); 407 } 408 409 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 410 SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) { 411 return getOrCreateSPIRVType( 412 FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 413 NumElements), 414 MIRBuilder); 415 } 416 417 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 418 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, 419 const SPIRVInstrInfo &TII) { 420 Type *LLVMTy = FixedVectorType::get( 421 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements); 422 MachineBasicBlock &BB = *I.getParent(); 423 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector)) 424 .addDef(createTypeVReg(CurMF->getRegInfo())) 425 .addUse(getSPIRVTypeID(BaseType)) 426 .addImm(NumElements); 427 return restOfCreateSPIRVType(LLVMTy, MIB); 428 } 429 430 SPIRVType * 431 SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(SPIRVType *BaseType, 432 MachineIRBuilder &MIRBuilder, 433 SPIRV::StorageClass SClass) { 434 return getOrCreateSPIRVType( 435 PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 436 storageClassToAddressSpace(SClass)), 437 MIRBuilder); 438 } 439 440 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( 441 SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII, 442 SPIRV::StorageClass SC) { 443 Type *LLVMTy = 444 PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 445 storageClassToAddressSpace(SC)); 446 MachineBasicBlock &BB = *I.getParent(); 447 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer)) 448 .addDef(createTypeVReg(CurMF->getRegInfo())) 449 .addImm(static_cast<uint32_t>(SC)) 450 .addUse(getSPIRVTypeID(BaseType)); 451 return restOfCreateSPIRVType(LLVMTy, MIB); 452 } 453