1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVSubtarget.h"
20 #include "SPIRVTargetMachine.h"
21 #include "SPIRVUtils.h"
22 
23 using namespace llvm;
24 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
25     : PointerSize(PointerSize) {}
26 
27 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
28     const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
29     SPIRV::AccessQualifier AccessQual, bool EmitIR) {
30 
31   SPIRVType *SpirvType =
32       getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
33   assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder);
34   return SpirvType;
35 }
36 
37 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
38                                                 Register VReg,
39                                                 MachineIRBuilder &MIRBuilder) {
40   VRegToTypeMap[&MIRBuilder.getMF()][VReg] = SpirvType;
41 }
42 
43 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
44   auto &MRI = MIRBuilder.getMF().getRegInfo();
45   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
46   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
47   return Res;
48 }
49 
50 static Register createTypeVReg(MachineRegisterInfo &MRI) {
51   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
52   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
53   return Res;
54 }
55 
56 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
57   return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
58       .addDef(createTypeVReg(MIRBuilder));
59 }
60 
61 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
62                                              MachineIRBuilder &MIRBuilder,
63                                              bool IsSigned) {
64   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
65                  .addDef(createTypeVReg(MIRBuilder))
66                  .addImm(Width)
67                  .addImm(IsSigned ? 1 : 0);
68   return MIB;
69 }
70 
71 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
72                                                MachineIRBuilder &MIRBuilder) {
73   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
74                  .addDef(createTypeVReg(MIRBuilder))
75                  .addImm(Width);
76   return MIB;
77 }
78 
79 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
80   return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
81       .addDef(createTypeVReg(MIRBuilder));
82 }
83 
84 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
85                                                 SPIRVType *ElemType,
86                                                 MachineIRBuilder &MIRBuilder) {
87   auto EleOpc = ElemType->getOpcode();
88   assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
89           EleOpc == SPIRV::OpTypeBool) &&
90          "Invalid vector element type");
91 
92   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
93                  .addDef(createTypeVReg(MIRBuilder))
94                  .addUse(getSPIRVTypeID(ElemType))
95                  .addImm(NumElems);
96   return MIB;
97 }
98 
99 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
100                                                MachineIRBuilder &MIRBuilder,
101                                                SPIRVType *SpvType,
102                                                bool EmitIR) {
103   auto &MF = MIRBuilder.getMF();
104   Register Res;
105   const IntegerType *LLVMIntTy;
106   if (SpvType)
107     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
108   else
109     LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
110   // Find a constant in DT or build a new one.
111   const auto ConstInt =
112       ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
113   unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
114   Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
115   assignTypeToVReg(LLVMIntTy, Res, MIRBuilder);
116   if (EmitIR)
117     MIRBuilder.buildConstant(Res, *ConstInt);
118   else
119     MIRBuilder.buildInstr(SPIRV::OpConstantI)
120         .addDef(Res)
121         .addImm(ConstInt->getSExtValue());
122   return Res;
123 }
124 
125 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
126                                               MachineIRBuilder &MIRBuilder,
127                                               SPIRVType *SpvType) {
128   auto &MF = MIRBuilder.getMF();
129   Register Res;
130   const Type *LLVMFPTy;
131   if (SpvType) {
132     LLVMFPTy = getTypeForSPIRVType(SpvType);
133     assert(LLVMFPTy->isFloatingPointTy());
134   } else {
135     LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext());
136   }
137   // Find a constant in DT or build a new one.
138   const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val);
139   unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
140   Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
141   assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
142   MIRBuilder.buildFConstant(Res, *ConstFP);
143   return Res;
144 }
145 
146 Register SPIRVGlobalRegistry::buildGlobalVariable(
147     Register ResVReg, SPIRVType *BaseType, StringRef Name,
148     const GlobalValue *GV, SPIRV::StorageClass Storage,
149     const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
150     SPIRV::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
151     bool IsInstSelector) {
152   const GlobalVariable *GVar = nullptr;
153   if (GV)
154     GVar = cast<const GlobalVariable>(GV);
155   else {
156     // If GV is not passed explicitly, use the name to find or construct
157     // the global variable.
158     Module *M = MIRBuilder.getMF().getFunction().getParent();
159     GVar = M->getGlobalVariable(Name);
160     if (GVar == nullptr) {
161       const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
162       GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
163                                 GlobalValue::ExternalLinkage, nullptr,
164                                 Twine(Name));
165     }
166     GV = GVar;
167   }
168   Register Reg;
169   auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
170                  .addDef(ResVReg)
171                  .addUse(getSPIRVTypeID(BaseType))
172                  .addImm(static_cast<uint32_t>(Storage));
173 
174   if (Init != 0) {
175     MIB.addUse(Init->getOperand(0).getReg());
176   }
177 
178   // ISel may introduce a new register on this step, so we need to add it to
179   // DT and correct its type avoiding fails on the next stage.
180   if (IsInstSelector) {
181     const auto &Subtarget = CurMF->getSubtarget();
182     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
183                                      *Subtarget.getRegisterInfo(),
184                                      *Subtarget.getRegBankInfo());
185   }
186   Reg = MIB->getOperand(0).getReg();
187 
188   // Set to Reg the same type as ResVReg has.
189   auto MRI = MIRBuilder.getMRI();
190   assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
191   if (Reg != ResVReg) {
192     LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
193     MRI->setType(Reg, RegLLTy);
194     assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder);
195   }
196 
197   // If it's a global variable with name, output OpName for it.
198   if (GVar && GVar->hasName())
199     buildOpName(Reg, GVar->getName(), MIRBuilder);
200 
201   // Output decorations for the GV.
202   // TODO: maybe move to GenerateDecorations pass.
203   if (IsConst)
204     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
205 
206   if (GVar && GVar->getAlign().valueOrOne().value() != 1)
207     buildOpDecorate(
208         Reg, MIRBuilder, SPIRV::Decoration::Alignment,
209         {static_cast<uint32_t>(GVar->getAlign().valueOrOne().value())});
210 
211   if (HasLinkageTy)
212     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
213                     {static_cast<uint32_t>(LinkageType)}, Name);
214   return Reg;
215 }
216 
217 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
218                                                SPIRVType *ElemType,
219                                                MachineIRBuilder &MIRBuilder,
220                                                bool EmitIR) {
221   assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
222          "Invalid array element type");
223   Register NumElementsVReg =
224       buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
225   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
226                  .addDef(createTypeVReg(MIRBuilder))
227                  .addUse(getSPIRVTypeID(ElemType))
228                  .addUse(NumElementsVReg);
229   return MIB;
230 }
231 
232 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(SPIRV::StorageClass SC,
233                                                  SPIRVType *ElemType,
234                                                  MachineIRBuilder &MIRBuilder) {
235   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypePointer)
236                  .addDef(createTypeVReg(MIRBuilder))
237                  .addImm(static_cast<uint32_t>(SC))
238                  .addUse(getSPIRVTypeID(ElemType));
239   return MIB;
240 }
241 
242 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
243     SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
244     MachineIRBuilder &MIRBuilder) {
245   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
246                  .addDef(createTypeVReg(MIRBuilder))
247                  .addUse(getSPIRVTypeID(RetType));
248   for (const SPIRVType *ArgType : ArgTypes)
249     MIB.addUse(getSPIRVTypeID(ArgType));
250   return MIB;
251 }
252 
253 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty,
254                                                 MachineIRBuilder &MIRBuilder,
255                                                 SPIRV::AccessQualifier AccQual,
256                                                 bool EmitIR) {
257   if (auto IType = dyn_cast<IntegerType>(Ty)) {
258     const unsigned Width = IType->getBitWidth();
259     return Width == 1 ? getOpTypeBool(MIRBuilder)
260                       : getOpTypeInt(Width, MIRBuilder, false);
261   }
262   if (Ty->isFloatingPointTy())
263     return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
264   if (Ty->isVoidTy())
265     return getOpTypeVoid(MIRBuilder);
266   if (Ty->isVectorTy()) {
267     auto El = getOrCreateSPIRVType(cast<FixedVectorType>(Ty)->getElementType(),
268                                    MIRBuilder);
269     return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
270                            MIRBuilder);
271   }
272   if (Ty->isArrayTy()) {
273     auto *El = getOrCreateSPIRVType(Ty->getArrayElementType(), MIRBuilder);
274     return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
275   }
276   assert(!isa<StructType>(Ty) && "Unsupported StructType");
277   if (auto FType = dyn_cast<FunctionType>(Ty)) {
278     SPIRVType *RetTy = getOrCreateSPIRVType(FType->getReturnType(), MIRBuilder);
279     SmallVector<SPIRVType *, 4> ParamTypes;
280     for (const auto &t : FType->params()) {
281       ParamTypes.push_back(getOrCreateSPIRVType(t, MIRBuilder));
282     }
283     return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
284   }
285   if (auto PType = dyn_cast<PointerType>(Ty)) {
286     Type *ElemType = PType->getPointerElementType();
287 
288     // Some OpenCL and SPIRV builtins like image2d_t are passed in as pointers,
289     // but should be treated as custom types like OpTypeImage.
290     assert(!isa<StructType>(ElemType) && "Unsupported StructType pointer");
291 
292     // Otherwise, treat it as a regular pointer type.
293     auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
294     SPIRVType *SpvElementType = getOrCreateSPIRVType(
295         ElemType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR);
296     return getOpTypePointer(SC, SpvElementType, MIRBuilder);
297   }
298   llvm_unreachable("Unable to convert LLVM type to SPIRVType");
299 }
300 
301 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
302   auto t = VRegToTypeMap.find(CurMF);
303   if (t != VRegToTypeMap.end()) {
304     auto tt = t->second.find(VReg);
305     if (tt != t->second.end())
306       return tt->second;
307   }
308   return nullptr;
309 }
310 
311 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
312     const Type *Type, MachineIRBuilder &MIRBuilder,
313     SPIRV::AccessQualifier AccessQual, bool EmitIR) {
314   SPIRVType *SpirvType = createSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
315   VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
316   SPIRVToLLVMType[SpirvType] = Type;
317   return SpirvType;
318 }
319 
320 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
321                                          unsigned TypeOpcode) const {
322   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
323   assert(Type && "isScalarOfType VReg has no type assigned");
324   return Type->getOpcode() == TypeOpcode;
325 }
326 
327 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
328                                                  unsigned TypeOpcode) const {
329   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
330   assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
331   if (Type->getOpcode() == TypeOpcode)
332     return true;
333   if (Type->getOpcode() == SPIRV::OpTypeVector) {
334     Register ScalarTypeVReg = Type->getOperand(1).getReg();
335     SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
336     return ScalarType->getOpcode() == TypeOpcode;
337   }
338   return false;
339 }
340 
341 unsigned
342 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
343   assert(Type && "Invalid Type pointer");
344   if (Type->getOpcode() == SPIRV::OpTypeVector) {
345     auto EleTypeReg = Type->getOperand(1).getReg();
346     Type = getSPIRVTypeForVReg(EleTypeReg);
347   }
348   if (Type->getOpcode() == SPIRV::OpTypeInt ||
349       Type->getOpcode() == SPIRV::OpTypeFloat)
350     return Type->getOperand(1).getImm();
351   if (Type->getOpcode() == SPIRV::OpTypeBool)
352     return 1;
353   llvm_unreachable("Attempting to get bit width of non-integer/float type.");
354 }
355 
356 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
357   assert(Type && "Invalid Type pointer");
358   if (Type->getOpcode() == SPIRV::OpTypeVector) {
359     auto EleTypeReg = Type->getOperand(1).getReg();
360     Type = getSPIRVTypeForVReg(EleTypeReg);
361   }
362   if (Type->getOpcode() == SPIRV::OpTypeInt)
363     return Type->getOperand(2).getImm() != 0;
364   llvm_unreachable("Attempting to get sign of non-integer type.");
365 }
366 
367 SPIRV::StorageClass
368 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
369   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
370   assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
371          Type->getOperand(1).isImm() && "Pointer type is expected");
372   return static_cast<SPIRV::StorageClass>(Type->getOperand(1).getImm());
373 }
374 
375 SPIRVType *
376 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
377                                                  MachineIRBuilder &MIRBuilder) {
378   return getOrCreateSPIRVType(
379       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
380       MIRBuilder);
381 }
382 
383 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(Type *LLVMTy,
384                                                       MachineInstrBuilder MIB) {
385   SPIRVType *SpirvType = MIB;
386   VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
387   SPIRVToLLVMType[SpirvType] = LLVMTy;
388   return SpirvType;
389 }
390 
391 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
392     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
393   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
394   MachineBasicBlock &BB = *I.getParent();
395   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
396                  .addDef(createTypeVReg(CurMF->getRegInfo()))
397                  .addImm(BitWidth)
398                  .addImm(0);
399   return restOfCreateSPIRVType(LLVMTy, MIB);
400 }
401 
402 SPIRVType *
403 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
404   return getOrCreateSPIRVType(
405       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
406       MIRBuilder);
407 }
408 
409 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
410     SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
411   return getOrCreateSPIRVType(
412       FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
413                            NumElements),
414       MIRBuilder);
415 }
416 
417 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
418     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
419     const SPIRVInstrInfo &TII) {
420   Type *LLVMTy = FixedVectorType::get(
421       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
422   MachineBasicBlock &BB = *I.getParent();
423   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
424                  .addDef(createTypeVReg(CurMF->getRegInfo()))
425                  .addUse(getSPIRVTypeID(BaseType))
426                  .addImm(NumElements);
427   return restOfCreateSPIRVType(LLVMTy, MIB);
428 }
429 
430 SPIRVType *
431 SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(SPIRVType *BaseType,
432                                                  MachineIRBuilder &MIRBuilder,
433                                                  SPIRV::StorageClass SClass) {
434   return getOrCreateSPIRVType(
435       PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
436                        storageClassToAddressSpace(SClass)),
437       MIRBuilder);
438 }
439 
440 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
441     SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
442     SPIRV::StorageClass SC) {
443   Type *LLVMTy =
444       PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
445                        storageClassToAddressSpace(SC));
446   MachineBasicBlock &BB = *I.getParent();
447   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
448                  .addDef(createTypeVReg(CurMF->getRegInfo()))
449                  .addImm(static_cast<uint32_t>(SC))
450                  .addUse(getSPIRVTypeID(BaseType));
451   return restOfCreateSPIRVType(LLVMTy, MIB);
452 }
453