1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVSubtarget.h"
20 #include "SPIRVTargetMachine.h"
21 #include "SPIRVUtils.h"
22
23 using namespace llvm;
SPIRVGlobalRegistry(unsigned PointerSize)24 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
25 : PointerSize(PointerSize) {}
26
assignIntTypeToVReg(unsigned BitWidth,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)27 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
28 Register VReg,
29 MachineInstr &I,
30 const SPIRVInstrInfo &TII) {
31 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
32 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
33 return SpirvType;
34 }
35
assignVectTypeToVReg(SPIRVType * BaseType,unsigned NumElements,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)36 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
37 SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
38 const SPIRVInstrInfo &TII) {
39 SPIRVType *SpirvType =
40 getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
41 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
42 return SpirvType;
43 }
44
assignTypeToVReg(const Type * Type,Register VReg,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier AccessQual,bool EmitIR)45 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
46 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
47 SPIRV::AccessQualifier AccessQual, bool EmitIR) {
48
49 SPIRVType *SpirvType =
50 getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
51 assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
52 return SpirvType;
53 }
54
assignSPIRVTypeToVReg(SPIRVType * SpirvType,Register VReg,MachineFunction & MF)55 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
56 Register VReg,
57 MachineFunction &MF) {
58 VRegToTypeMap[&MF][VReg] = SpirvType;
59 }
60
createTypeVReg(MachineIRBuilder & MIRBuilder)61 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
62 auto &MRI = MIRBuilder.getMF().getRegInfo();
63 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
64 MRI.setRegClass(Res, &SPIRV::TYPERegClass);
65 return Res;
66 }
67
createTypeVReg(MachineRegisterInfo & MRI)68 static Register createTypeVReg(MachineRegisterInfo &MRI) {
69 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
70 MRI.setRegClass(Res, &SPIRV::TYPERegClass);
71 return Res;
72 }
73
getOpTypeBool(MachineIRBuilder & MIRBuilder)74 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
75 return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
76 .addDef(createTypeVReg(MIRBuilder));
77 }
78
getOpTypeInt(uint32_t Width,MachineIRBuilder & MIRBuilder,bool IsSigned)79 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
80 MachineIRBuilder &MIRBuilder,
81 bool IsSigned) {
82 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
83 .addDef(createTypeVReg(MIRBuilder))
84 .addImm(Width)
85 .addImm(IsSigned ? 1 : 0);
86 return MIB;
87 }
88
getOpTypeFloat(uint32_t Width,MachineIRBuilder & MIRBuilder)89 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
90 MachineIRBuilder &MIRBuilder) {
91 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
92 .addDef(createTypeVReg(MIRBuilder))
93 .addImm(Width);
94 return MIB;
95 }
96
getOpTypeVoid(MachineIRBuilder & MIRBuilder)97 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
98 return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
99 .addDef(createTypeVReg(MIRBuilder));
100 }
101
getOpTypeVector(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder)102 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
103 SPIRVType *ElemType,
104 MachineIRBuilder &MIRBuilder) {
105 auto EleOpc = ElemType->getOpcode();
106 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
107 EleOpc == SPIRV::OpTypeBool) &&
108 "Invalid vector element type");
109
110 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
111 .addDef(createTypeVReg(MIRBuilder))
112 .addUse(getSPIRVTypeID(ElemType))
113 .addImm(NumElems);
114 return MIB;
115 }
116
117 std::tuple<Register, ConstantInt *, bool>
getOrCreateConstIntReg(uint64_t Val,SPIRVType * SpvType,MachineIRBuilder * MIRBuilder,MachineInstr * I,const SPIRVInstrInfo * TII)118 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
119 MachineIRBuilder *MIRBuilder,
120 MachineInstr *I,
121 const SPIRVInstrInfo *TII) {
122 const IntegerType *LLVMIntTy;
123 if (SpvType)
124 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
125 else
126 LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
127 bool NewInstr = false;
128 // Find a constant in DT or build a new one.
129 ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
130 Register Res = DT.find(CI, CurMF);
131 if (!Res.isValid()) {
132 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
133 LLT LLTy = LLT::scalar(32);
134 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
135 if (MIRBuilder)
136 assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
137 else
138 assignIntTypeToVReg(BitWidth, Res, *I, *TII);
139 DT.add(CI, CurMF, Res);
140 NewInstr = true;
141 }
142 return std::make_tuple(Res, CI, NewInstr);
143 }
144
getOrCreateConstInt(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)145 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
146 SPIRVType *SpvType,
147 const SPIRVInstrInfo &TII) {
148 assert(SpvType);
149 ConstantInt *CI;
150 Register Res;
151 bool New;
152 std::tie(Res, CI, New) =
153 getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
154 // If we have found Res register which is defined by the passed G_CONSTANT
155 // machine instruction, a new constant instruction should be created.
156 if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
157 return Res;
158 MachineInstrBuilder MIB;
159 MachineBasicBlock &BB = *I.getParent();
160 if (Val) {
161 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
162 .addDef(Res)
163 .addUse(getSPIRVTypeID(SpvType));
164 addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);
165 } else {
166 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
167 .addDef(Res)
168 .addUse(getSPIRVTypeID(SpvType));
169 }
170 const auto &ST = CurMF->getSubtarget();
171 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
172 *ST.getRegisterInfo(), *ST.getRegBankInfo());
173 return Res;
174 }
175
buildConstantInt(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)176 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
177 MachineIRBuilder &MIRBuilder,
178 SPIRVType *SpvType,
179 bool EmitIR) {
180 auto &MF = MIRBuilder.getMF();
181 const IntegerType *LLVMIntTy;
182 if (SpvType)
183 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
184 else
185 LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
186 // Find a constant in DT or build a new one.
187 const auto ConstInt =
188 ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
189 Register Res = DT.find(ConstInt, &MF);
190 if (!Res.isValid()) {
191 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
192 LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
193 Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
194 assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
195 SPIRV::AccessQualifier::ReadWrite, EmitIR);
196 DT.add(ConstInt, &MIRBuilder.getMF(), Res);
197 if (EmitIR) {
198 MIRBuilder.buildConstant(Res, *ConstInt);
199 } else {
200 MachineInstrBuilder MIB;
201 if (Val) {
202 assert(SpvType);
203 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
204 .addDef(Res)
205 .addUse(getSPIRVTypeID(SpvType));
206 addNumImm(APInt(BitWidth, Val), MIB);
207 } else {
208 assert(SpvType);
209 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
210 .addDef(Res)
211 .addUse(getSPIRVTypeID(SpvType));
212 }
213 const auto &Subtarget = CurMF->getSubtarget();
214 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
215 *Subtarget.getRegisterInfo(),
216 *Subtarget.getRegBankInfo());
217 }
218 }
219 return Res;
220 }
221
buildConstantFP(APFloat Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)222 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
223 MachineIRBuilder &MIRBuilder,
224 SPIRVType *SpvType) {
225 auto &MF = MIRBuilder.getMF();
226 const Type *LLVMFPTy;
227 if (SpvType) {
228 LLVMFPTy = getTypeForSPIRVType(SpvType);
229 assert(LLVMFPTy->isFloatingPointTy());
230 } else {
231 LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext());
232 }
233 // Find a constant in DT or build a new one.
234 const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val);
235 Register Res = DT.find(ConstFP, &MF);
236 if (!Res.isValid()) {
237 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
238 Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
239 assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
240 DT.add(ConstFP, &MF, Res);
241 MIRBuilder.buildFConstant(Res, *ConstFP);
242 }
243 return Res;
244 }
245
246 Register
getOrCreateConsIntVector(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)247 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
248 SPIRVType *SpvType,
249 const SPIRVInstrInfo &TII) {
250 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
251 assert(LLVMTy->isVectorTy());
252 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
253 Type *LLVMBaseTy = LLVMVecTy->getElementType();
254 // Find a constant vector in DT or build a new one.
255 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
256 auto ConstVec =
257 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
258 Register Res = DT.find(ConstVec, CurMF);
259 if (!Res.isValid()) {
260 unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
261 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
262 // SpvScalConst should be created before SpvVecConst to avoid undefined ID
263 // error on validation.
264 // TODO: can moved below once sorting of types/consts/defs is implemented.
265 Register SpvScalConst;
266 if (Val)
267 SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII);
268 // TODO: maybe use bitwidth of base type.
269 LLT LLTy = LLT::scalar(32);
270 Register SpvVecConst =
271 CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
272 const unsigned ElemCnt = SpvType->getOperand(2).getImm();
273 assignVectTypeToVReg(SpvBaseType, ElemCnt, SpvVecConst, I, TII);
274 DT.add(ConstVec, CurMF, SpvVecConst);
275 MachineInstrBuilder MIB;
276 MachineBasicBlock &BB = *I.getParent();
277 if (Val) {
278 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
279 .addDef(SpvVecConst)
280 .addUse(getSPIRVTypeID(SpvType));
281 for (unsigned i = 0; i < ElemCnt; ++i)
282 MIB.addUse(SpvScalConst);
283 } else {
284 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
285 .addDef(SpvVecConst)
286 .addUse(getSPIRVTypeID(SpvType));
287 }
288 const auto &Subtarget = CurMF->getSubtarget();
289 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
290 *Subtarget.getRegisterInfo(),
291 *Subtarget.getRegBankInfo());
292 return SpvVecConst;
293 }
294 return Res;
295 }
296
buildGlobalVariable(Register ResVReg,SPIRVType * BaseType,StringRef Name,const GlobalValue * GV,SPIRV::StorageClass Storage,const MachineInstr * Init,bool IsConst,bool HasLinkageTy,SPIRV::LinkageType LinkageType,MachineIRBuilder & MIRBuilder,bool IsInstSelector)297 Register SPIRVGlobalRegistry::buildGlobalVariable(
298 Register ResVReg, SPIRVType *BaseType, StringRef Name,
299 const GlobalValue *GV, SPIRV::StorageClass Storage,
300 const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
301 SPIRV::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
302 bool IsInstSelector) {
303 const GlobalVariable *GVar = nullptr;
304 if (GV)
305 GVar = cast<const GlobalVariable>(GV);
306 else {
307 // If GV is not passed explicitly, use the name to find or construct
308 // the global variable.
309 Module *M = MIRBuilder.getMF().getFunction().getParent();
310 GVar = M->getGlobalVariable(Name);
311 if (GVar == nullptr) {
312 const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
313 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
314 GlobalValue::ExternalLinkage, nullptr,
315 Twine(Name));
316 }
317 GV = GVar;
318 }
319 Register Reg = DT.find(GVar, &MIRBuilder.getMF());
320 if (Reg.isValid()) {
321 if (Reg != ResVReg)
322 MIRBuilder.buildCopy(ResVReg, Reg);
323 return ResVReg;
324 }
325
326 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
327 .addDef(ResVReg)
328 .addUse(getSPIRVTypeID(BaseType))
329 .addImm(static_cast<uint32_t>(Storage));
330
331 if (Init != 0) {
332 MIB.addUse(Init->getOperand(0).getReg());
333 }
334
335 // ISel may introduce a new register on this step, so we need to add it to
336 // DT and correct its type avoiding fails on the next stage.
337 if (IsInstSelector) {
338 const auto &Subtarget = CurMF->getSubtarget();
339 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
340 *Subtarget.getRegisterInfo(),
341 *Subtarget.getRegBankInfo());
342 }
343 Reg = MIB->getOperand(0).getReg();
344 DT.add(GVar, &MIRBuilder.getMF(), Reg);
345
346 // Set to Reg the same type as ResVReg has.
347 auto MRI = MIRBuilder.getMRI();
348 assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
349 if (Reg != ResVReg) {
350 LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
351 MRI->setType(Reg, RegLLTy);
352 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
353 }
354
355 // If it's a global variable with name, output OpName for it.
356 if (GVar && GVar->hasName())
357 buildOpName(Reg, GVar->getName(), MIRBuilder);
358
359 // Output decorations for the GV.
360 // TODO: maybe move to GenerateDecorations pass.
361 if (IsConst)
362 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
363
364 if (GVar && GVar->getAlign().valueOrOne().value() != 1)
365 buildOpDecorate(
366 Reg, MIRBuilder, SPIRV::Decoration::Alignment,
367 {static_cast<uint32_t>(GVar->getAlign().valueOrOne().value())});
368
369 if (HasLinkageTy)
370 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
371 {static_cast<uint32_t>(LinkageType)}, Name);
372 return Reg;
373 }
374
getOpTypeArray(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,bool EmitIR)375 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
376 SPIRVType *ElemType,
377 MachineIRBuilder &MIRBuilder,
378 bool EmitIR) {
379 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
380 "Invalid array element type");
381 Register NumElementsVReg =
382 buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
383 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
384 .addDef(createTypeVReg(MIRBuilder))
385 .addUse(getSPIRVTypeID(ElemType))
386 .addUse(NumElementsVReg);
387 return MIB;
388 }
389
getOpTypeOpaque(const StructType * Ty,MachineIRBuilder & MIRBuilder)390 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
391 MachineIRBuilder &MIRBuilder) {
392 assert(Ty->hasName());
393 const StringRef Name = Ty->hasName() ? Ty->getName() : "";
394 Register ResVReg = createTypeVReg(MIRBuilder);
395 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
396 addStringImm(Name, MIB);
397 buildOpName(ResVReg, Name, MIRBuilder);
398 return MIB;
399 }
400
getOpTypeStruct(const StructType * Ty,MachineIRBuilder & MIRBuilder,bool EmitIR)401 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
402 MachineIRBuilder &MIRBuilder,
403 bool EmitIR) {
404 SmallVector<Register, 4> FieldTypes;
405 for (const auto &Elem : Ty->elements()) {
406 SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
407 assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
408 "Invalid struct element type");
409 FieldTypes.push_back(getSPIRVTypeID(ElemTy));
410 }
411 Register ResVReg = createTypeVReg(MIRBuilder);
412 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
413 for (const auto &Ty : FieldTypes)
414 MIB.addUse(Ty);
415 if (Ty->hasName())
416 buildOpName(ResVReg, Ty->getName(), MIRBuilder);
417 if (Ty->isPacked())
418 buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
419 return MIB;
420 }
421
isOpenCLBuiltinType(const StructType * SType)422 static bool isOpenCLBuiltinType(const StructType *SType) {
423 return SType->isOpaque() && SType->hasName() &&
424 SType->getName().startswith("opencl.");
425 }
426
isSPIRVBuiltinType(const StructType * SType)427 static bool isSPIRVBuiltinType(const StructType *SType) {
428 return SType->isOpaque() && SType->hasName() &&
429 SType->getName().startswith("spirv.");
430 }
431
isSpecialType(const Type * Ty)432 static bool isSpecialType(const Type *Ty) {
433 if (auto PType = dyn_cast<PointerType>(Ty)) {
434 if (!PType->isOpaque())
435 Ty = PType->getNonOpaquePointerElementType();
436 }
437 if (auto SType = dyn_cast<StructType>(Ty))
438 return isOpenCLBuiltinType(SType) || isSPIRVBuiltinType(SType);
439 return false;
440 }
441
getOpTypePointer(SPIRV::StorageClass SC,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,Register Reg)442 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(SPIRV::StorageClass SC,
443 SPIRVType *ElemType,
444 MachineIRBuilder &MIRBuilder,
445 Register Reg) {
446 if (!Reg.isValid())
447 Reg = createTypeVReg(MIRBuilder);
448 return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
449 .addDef(Reg)
450 .addImm(static_cast<uint32_t>(SC))
451 .addUse(getSPIRVTypeID(ElemType));
452 }
453
454 SPIRVType *
getOpTypeForwardPointer(SPIRV::StorageClass SC,MachineIRBuilder & MIRBuilder)455 SPIRVGlobalRegistry::getOpTypeForwardPointer(SPIRV::StorageClass SC,
456 MachineIRBuilder &MIRBuilder) {
457 return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
458 .addUse(createTypeVReg(MIRBuilder))
459 .addImm(static_cast<uint32_t>(SC));
460 }
461
getOpTypeFunction(SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)462 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
463 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
464 MachineIRBuilder &MIRBuilder) {
465 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
466 .addDef(createTypeVReg(MIRBuilder))
467 .addUse(getSPIRVTypeID(RetType));
468 for (const SPIRVType *ArgType : ArgTypes)
469 MIB.addUse(getSPIRVTypeID(ArgType));
470 return MIB;
471 }
472
getOrCreateOpTypeFunctionWithArgs(const Type * Ty,SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)473 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
474 const Type *Ty, SPIRVType *RetType,
475 const SmallVectorImpl<SPIRVType *> &ArgTypes,
476 MachineIRBuilder &MIRBuilder) {
477 Register Reg = DT.find(Ty, &MIRBuilder.getMF());
478 if (Reg.isValid())
479 return getSPIRVTypeForVReg(Reg);
480 SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
481 return finishCreatingSPIRVType(Ty, SpirvType);
482 }
483
findSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier AccQual,bool EmitIR)484 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(const Type *Ty,
485 MachineIRBuilder &MIRBuilder,
486 SPIRV::AccessQualifier AccQual,
487 bool EmitIR) {
488 Register Reg = DT.find(Ty, &MIRBuilder.getMF());
489 if (Reg.isValid())
490 return getSPIRVTypeForVReg(Reg);
491 if (ForwardPointerTypes.find(Ty) != ForwardPointerTypes.end())
492 return ForwardPointerTypes[Ty];
493 return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
494 }
495
getSPIRVTypeID(const SPIRVType * SpirvType) const496 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
497 assert(SpirvType && "Attempting to get type id for nullptr type.");
498 if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer)
499 return SpirvType->uses().begin()->getReg();
500 return SpirvType->defs().begin()->getReg();
501 }
502
/// Emit the SPIR-V type instruction for the LLVM type \p Ty at MIRBuilder's
/// insertion point and return it.
///
/// Component types are resolved first through findSPIRVType, which may emit
/// them before the instruction for \p Ty. These are vector and array
/// elements, function return and parameter types, and typed-pointer pointee
/// types (struct members are resolved inside getOpTypeStruct). The
/// instruction for \p Ty itself is not recorded in VRegToTypeMap or DT here;
/// restOfCreateSPIRVType does that.
///
/// \p Ty must not be an OpenCL/SPIR-V builtin ("special") type (asserted).
/// \p AccQual is not referenced in this body.
/// \p EmitIR is forwarded to array, struct and typed-pointer handling.
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty,
                                                MachineIRBuilder &MIRBuilder,
                                                SPIRV::AccessQualifier AccQual,
                                                bool EmitIR) {
  assert(!isSpecialType(Ty));
  // If the duplicates tracker already holds an id for Ty in this machine
  // function, return the existing type instruction instead of a new one.
  auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
  auto t = TypeToSPIRVTypeMap.find(Ty);
  if (t != TypeToSPIRVTypeMap.end()) {
    auto tt = t->second.find(&MIRBuilder.getMF());
    if (tt != t->second.end())
      return getSPIRVTypeForVReg(tt->second);
  }

  // i1 becomes OpTypeBool; any other width becomes OpTypeInt with the
  // signedness operand set to 0.
  if (auto IType = dyn_cast<IntegerType>(Ty)) {
    const unsigned Width = IType->getBitWidth();
    return Width == 1 ? getOpTypeBool(MIRBuilder)
                      : getOpTypeInt(Width, MIRBuilder, false);
  }
  if (Ty->isFloatingPointTy())
    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
  if (Ty->isVoidTy())
    return getOpTypeVoid(MIRBuilder);
  // Only fixed-length vectors are handled (cast<> asserts otherwise).
  if (Ty->isVectorTy()) {
    SPIRVType *El =
        findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
    return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
                           MIRBuilder);
  }
  if (Ty->isArrayTy()) {
    SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
    return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
  }
  if (auto SType = dyn_cast<StructType>(Ty)) {
    if (SType->isOpaque())
      return getOpTypeOpaque(SType, MIRBuilder);
    return getOpTypeStruct(SType, MIRBuilder, EmitIR);
  }
  if (auto FType = dyn_cast<FunctionType>(Ty)) {
    SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
    SmallVector<SPIRVType *, 4> ParamTypes;
    for (const auto &t : FType->params()) {
      ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
    }
    return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
  }
  if (auto PType = dyn_cast<PointerType>(Ty)) {
    SPIRVType *SpvElementType;
    // At the moment, all opaque pointers correspond to i8 element type.
    // TODO: change the implementation once opaque pointers are supported
    // in the SPIR-V specification.
    // NOTE(review): getOrCreateSPIRVIntegerType goes through
    // getOrCreateSPIRVType, which clears TypesInProcessing if i8 is not yet
    // registered. Confirm this cannot break cycle detection of an enclosing
    // type.
    if (PType->isOpaque())
      SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
    else
      SpvElementType =
          findSPIRVType(PType->getNonOpaquePointerElementType(), MIRBuilder,
                        SPIRV::AccessQualifier::ReadWrite, EmitIR);
    auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
    // Null pointer means we have a loop in type definitions, make and
    // return corresponding OpTypeForwardPointer.
    // (findSPIRVType yields nullptr when restOfCreateSPIRVType finds the
    // pointee already in TypesInProcessing.) Ty and PType are the same
    // pointer here.
    if (SpvElementType == nullptr) {
      if (ForwardPointerTypes.find(Ty) == ForwardPointerTypes.end())
        ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder);
      return ForwardPointerTypes[PType];
    }
    Register Reg(0);
    // If we have forward pointer associated with this type, use its register
    // operand to create OpTypePointer.
    if (ForwardPointerTypes.find(PType) != ForwardPointerTypes.end())
      Reg = getSPIRVTypeID(ForwardPointerTypes[PType]);

    return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
  }
  llvm_unreachable("Unable to convert LLVM type to SPIRVType");
}
577
/// Build the SPIR-V type for \p Ty with createSPIRVType and register it:
///  - record it in VRegToTypeMap for MIRBuilder's machine function;
///  - record it in SPIRVToLLVMType;
///  - add it to the duplicates tracker.
/// The DT entry is skipped for an OpTypeForwardPointer, for a special
/// (builtin) type, and when Ty already has an entry in DT.
///
/// Returns nullptr without emitting anything when \p Ty is a non-pointer type
/// already present in TypesInProcessing, i.e. a cycle in the type graph. The
/// pointer branch of createSPIRVType reacts to that nullptr by emitting an
/// OpTypeForwardPointer.
SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier AccessQual, bool EmitIR) {
  // Pointer types are never cut off here, presumably so that the cycle is
  // reported at the pointee and turned into a forward pointer.
  if (TypesInProcessing.count(Ty) && !Ty->isPointerTy())
    return nullptr;
  TypesInProcessing.insert(Ty);
  SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  // NOTE(review): for a pointer Ty that was already in TypesInProcessing, this
  // erase also removes the outer frame's marker. Confirm that is intended.
  TypesInProcessing.erase(Ty);
  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
  SPIRVToLLVMType[SpirvType] = Ty;
  Register Reg = DT.find(Ty, &MIRBuilder.getMF());
  // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
  // will be added later. For special types it is already added to DT.
  if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
      !isSpecialType(Ty))
    DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));

  return SpirvType;
}
597
getSPIRVTypeForVReg(Register VReg) const598 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
599 auto t = VRegToTypeMap.find(CurMF);
600 if (t != VRegToTypeMap.end()) {
601 auto tt = t->second.find(VReg);
602 if (tt != t->second.end())
603 return tt->second;
604 }
605 return nullptr;
606 }
607
/// Return the SPIR-V type for \p Ty in MIRBuilder's machine function. If it
/// does not exist yet, it is emitted together with any missing component
/// types.
///
/// After the main creation, every pending entry in ForwardPointerTypes is
/// resolved. The real OpTypePointer for it is either found in DT or created
/// through restOfCreateSPIRVType; createSPIRVType then reuses the forward
/// pointer's id. If \p Ty itself is one of those pointer types, the resolved
/// type is returned in place of the forward pointer. ForwardPointerTypes is
/// cleared before returning.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier AccessQual, bool EmitIR) {
  Register Reg = DT.find(Ty, &MIRBuilder.getMF());
  if (Reg.isValid())
    return getSPIRVTypeForVReg(Reg);
  // Start cycle detection afresh for this top-level request.
  TypesInProcessing.clear();
  SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  // Create normal pointer types for the corresponding OpTypeForwardPointers.
  // NOTE(review): restOfCreateSPIRVType runs while ForwardPointerTypes is being
  // iterated. This relies on it not inserting new entries (the pointee types
  // should be in DT by now), since an insertion could invalidate the
  // iterator. Confirm.
  for (auto &CU : ForwardPointerTypes) {
    const Type *Ty2 = CU.first;
    // The initial value is overwritten on both branches below.
    SPIRVType *STy2 = CU.second;
    if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
      STy2 = getSPIRVTypeForVReg(Reg);
    else
      STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
    if (Ty == Ty2)
      STy = STy2;
  }
  ForwardPointerTypes.clear();
  return STy;
}
630
isScalarOfType(Register VReg,unsigned TypeOpcode) const631 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
632 unsigned TypeOpcode) const {
633 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
634 assert(Type && "isScalarOfType VReg has no type assigned");
635 return Type->getOpcode() == TypeOpcode;
636 }
637
isScalarOrVectorOfType(Register VReg,unsigned TypeOpcode) const638 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
639 unsigned TypeOpcode) const {
640 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
641 assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
642 if (Type->getOpcode() == TypeOpcode)
643 return true;
644 if (Type->getOpcode() == SPIRV::OpTypeVector) {
645 Register ScalarTypeVReg = Type->getOperand(1).getReg();
646 SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
647 return ScalarType->getOpcode() == TypeOpcode;
648 }
649 return false;
650 }
651
652 unsigned
getScalarOrVectorBitWidth(const SPIRVType * Type) const653 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
654 assert(Type && "Invalid Type pointer");
655 if (Type->getOpcode() == SPIRV::OpTypeVector) {
656 auto EleTypeReg = Type->getOperand(1).getReg();
657 Type = getSPIRVTypeForVReg(EleTypeReg);
658 }
659 if (Type->getOpcode() == SPIRV::OpTypeInt ||
660 Type->getOpcode() == SPIRV::OpTypeFloat)
661 return Type->getOperand(1).getImm();
662 if (Type->getOpcode() == SPIRV::OpTypeBool)
663 return 1;
664 llvm_unreachable("Attempting to get bit width of non-integer/float type.");
665 }
666
isScalarOrVectorSigned(const SPIRVType * Type) const667 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
668 assert(Type && "Invalid Type pointer");
669 if (Type->getOpcode() == SPIRV::OpTypeVector) {
670 auto EleTypeReg = Type->getOperand(1).getReg();
671 Type = getSPIRVTypeForVReg(EleTypeReg);
672 }
673 if (Type->getOpcode() == SPIRV::OpTypeInt)
674 return Type->getOperand(2).getImm() != 0;
675 llvm_unreachable("Attempting to get sign of non-integer type.");
676 }
677
678 SPIRV::StorageClass
getPointerStorageClass(Register VReg) const679 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
680 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
681 assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
682 Type->getOperand(1).isImm() && "Pointer type is expected");
683 return static_cast<SPIRV::StorageClass>(Type->getOperand(1).getImm());
684 }
685
686 SPIRVType *
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineIRBuilder & MIRBuilder)687 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
688 MachineIRBuilder &MIRBuilder) {
689 return getOrCreateSPIRVType(
690 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
691 MIRBuilder);
692 }
693
finishCreatingSPIRVType(const Type * LLVMTy,SPIRVType * SpirvType)694 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
695 SPIRVType *SpirvType) {
696 assert(CurMF == SpirvType->getMF());
697 VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
698 SPIRVToLLVMType[SpirvType] = LLVMTy;
699 DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType));
700 return SpirvType;
701 }
702
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineInstr & I,const SPIRVInstrInfo & TII)703 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
704 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
705 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
706 Register Reg = DT.find(LLVMTy, CurMF);
707 if (Reg.isValid())
708 return getSPIRVTypeForVReg(Reg);
709 MachineBasicBlock &BB = *I.getParent();
710 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
711 .addDef(createTypeVReg(CurMF->getRegInfo()))
712 .addImm(BitWidth)
713 .addImm(0);
714 return finishCreatingSPIRVType(LLVMTy, MIB);
715 }
716
717 SPIRVType *
getOrCreateSPIRVBoolType(MachineIRBuilder & MIRBuilder)718 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
719 return getOrCreateSPIRVType(
720 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
721 MIRBuilder);
722 }
723
724 SPIRVType *
getOrCreateSPIRVBoolType(MachineInstr & I,const SPIRVInstrInfo & TII)725 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
726 const SPIRVInstrInfo &TII) {
727 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
728 Register Reg = DT.find(LLVMTy, CurMF);
729 if (Reg.isValid())
730 return getSPIRVTypeForVReg(Reg);
731 MachineBasicBlock &BB = *I.getParent();
732 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
733 .addDef(createTypeVReg(CurMF->getRegInfo()));
734 return finishCreatingSPIRVType(LLVMTy, MIB);
735 }
736
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineIRBuilder & MIRBuilder)737 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
738 SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
739 return getOrCreateSPIRVType(
740 FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
741 NumElements),
742 MIRBuilder);
743 }
744
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineInstr & I,const SPIRVInstrInfo & TII)745 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
746 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
747 const SPIRVInstrInfo &TII) {
748 Type *LLVMTy = FixedVectorType::get(
749 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
750 Register Reg = DT.find(LLVMTy, CurMF);
751 if (Reg.isValid())
752 return getSPIRVTypeForVReg(Reg);
753 MachineBasicBlock &BB = *I.getParent();
754 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
755 .addDef(createTypeVReg(CurMF->getRegInfo()))
756 .addUse(getSPIRVTypeID(BaseType))
757 .addImm(NumElements);
758 return finishCreatingSPIRVType(LLVMTy, MIB);
759 }
760
761 SPIRVType *
getOrCreateSPIRVPointerType(SPIRVType * BaseType,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass SClass)762 SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(SPIRVType *BaseType,
763 MachineIRBuilder &MIRBuilder,
764 SPIRV::StorageClass SClass) {
765 return getOrCreateSPIRVType(
766 PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
767 storageClassToAddressSpace(SClass)),
768 MIRBuilder);
769 }
770
getOrCreateSPIRVPointerType(SPIRVType * BaseType,MachineInstr & I,const SPIRVInstrInfo & TII,SPIRV::StorageClass SC)771 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
772 SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
773 SPIRV::StorageClass SC) {
774 Type *LLVMTy =
775 PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
776 storageClassToAddressSpace(SC));
777 Register Reg = DT.find(LLVMTy, CurMF);
778 if (Reg.isValid())
779 return getSPIRVTypeForVReg(Reg);
780 MachineBasicBlock &BB = *I.getParent();
781 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
782 .addDef(createTypeVReg(CurMF->getRegInfo()))
783 .addImm(static_cast<uint32_t>(SC))
784 .addUse(getSPIRVTypeID(BaseType));
785 return finishCreatingSPIRVType(LLVMTy, MIB);
786 }
787
getOrCreateUndef(MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)788 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
789 SPIRVType *SpvType,
790 const SPIRVInstrInfo &TII) {
791 assert(SpvType);
792 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
793 assert(LLVMTy);
794 // Find a constant in DT or build a new one.
795 UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
796 Register Res = DT.find(UV, CurMF);
797 if (Res.isValid())
798 return Res;
799 LLT LLTy = LLT::scalar(32);
800 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
801 assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
802 DT.add(UV, CurMF, Res);
803
804 MachineInstrBuilder MIB;
805 MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
806 .addDef(Res)
807 .addUse(getSPIRVTypeID(SpvType));
808 const auto &ST = CurMF->getSubtarget();
809 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
810 *ST.getRegisterInfo(), *ST.getRegBankInfo());
811 return Res;
812 }
813