1 //===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis collects instructions that should be output at the module level
10 // and performs the global register numbering.
11 //
12 // The results of this analysis are used in AsmPrinter to rename registers
13 // globally and to output required instructions at the module level.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "SPIRVModuleAnalysis.h"
18 #include "MCTargetDesc/SPIRVBaseInfo.h"
19 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
20 #include "SPIRV.h"
21 #include "SPIRVSubtarget.h"
22 #include "SPIRVTargetMachine.h"
23 #include "SPIRVUtils.h"
24 #include "TargetInfo/SPIRVTargetInfo.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/CodeGen/MachineModuleInfo.h"
27 #include "llvm/CodeGen/TargetPassConfig.h"
28
29 using namespace llvm;
30
31 #define DEBUG_TYPE "spirv-module-analysis"
32
33 static cl::opt<bool>
34 SPVDumpDeps("spv-dump-deps",
35 cl::desc("Dump MIR with SPIR-V dependencies info"),
36 cl::Optional, cl::init(false));
37
38 char llvm::SPIRVModuleAnalysis::ID = 0;
39
40 namespace llvm {
41 void initializeSPIRVModuleAnalysisPass(PassRegistry &);
42 } // namespace llvm
43
44 INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
45 true)
46
47 // Retrieve an unsigned from an MDNode with a list of them as operands.
getMetadataUInt(MDNode * MdNode,unsigned OpIndex,unsigned DefaultVal=0)48 static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
49 unsigned DefaultVal = 0) {
50 if (MdNode && OpIndex < MdNode->getNumOperands()) {
51 const auto &Op = MdNode->getOperand(OpIndex);
52 return mdconst::extract<ConstantInt>(Op)->getZExtValue();
53 }
54 return DefaultVal;
55 }
56
57 static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,unsigned i,const SPIRVSubtarget & ST,SPIRV::RequirementHandler & Reqs)58 getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
59 unsigned i, const SPIRVSubtarget &ST,
60 SPIRV::RequirementHandler &Reqs) {
61 unsigned ReqMinVer = getSymbolicOperandMinVersion(Category, i);
62 unsigned ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
63 unsigned TargetVer = ST.getSPIRVVersion();
64 bool MinVerOK = !ReqMinVer || !TargetVer || TargetVer >= ReqMinVer;
65 bool MaxVerOK = !ReqMaxVer || !TargetVer || TargetVer <= ReqMaxVer;
66 CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
67 ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
68 if (ReqCaps.empty()) {
69 if (ReqExts.empty()) {
70 if (MinVerOK && MaxVerOK)
71 return {true, {}, {}, ReqMinVer, ReqMaxVer};
72 return {false, {}, {}, 0, 0};
73 }
74 } else if (MinVerOK && MaxVerOK) {
75 for (auto Cap : ReqCaps) { // Only need 1 of the capabilities to work.
76 if (Reqs.isCapabilityAvailable(Cap))
77 return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
78 }
79 }
80 // If there are no capabilities, or we can't satisfy the version or
81 // capability requirements, use the list of extensions (if the subtarget
82 // can handle them all).
83 if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
84 return ST.canUseExtension(Ext);
85 })) {
86 return {true, {}, ReqExts, 0, 0}; // TODO: add versions to extensions.
87 }
88 return {false, {}, {}, 0, 0};
89 }
90
setBaseInfo(const Module & M)91 void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
92 MAI.MaxID = 0;
93 for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
94 MAI.MS[i].clear();
95 MAI.RegisterAliasTable.clear();
96 MAI.InstrsToDelete.clear();
97 MAI.FuncMap.clear();
98 MAI.GlobalVarList.clear();
99 MAI.ExtInstSetMap.clear();
100 MAI.Reqs.clear();
101 MAI.Reqs.initAvailableCapabilities(*ST);
102
103 // TODO: determine memory model and source language from the configuratoin.
104 if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
105 auto MemMD = MemModel->getOperand(0);
106 MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
107 getMetadataUInt(MemMD, 0));
108 MAI.Mem =
109 static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
110 } else {
111 // TODO: Add support for VulkanMemoryModel.
112 MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL
113 : SPIRV::MemoryModel::GLSL450;
114 if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
115 unsigned PtrSize = ST->getPointerSize();
116 MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32
117 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
118 : SPIRV::AddressingModel::Logical;
119 } else {
120 // TODO: Add support for PhysicalStorageBufferAddress.
121 MAI.Addr = SPIRV::AddressingModel::Logical;
122 }
123 }
124 // Get the OpenCL version number from metadata.
125 // TODO: support other source languages.
126 if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
127 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
128 // Construct version literal in accordance with SPIRV-LLVM-Translator.
129 // TODO: support multiple OCL version metadata.
130 assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
131 auto VersionMD = VerNode->getOperand(0);
132 unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
133 unsigned MinorNum = getMetadataUInt(VersionMD, 1);
134 unsigned RevNum = getMetadataUInt(VersionMD, 2);
135 MAI.SrcLangVersion = (MajorNum * 100 + MinorNum) * 1000 + RevNum;
136 } else {
137 MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
138 MAI.SrcLangVersion = 0;
139 }
140
141 if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
142 for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
143 MDNode *MD = ExtNode->getOperand(I);
144 if (!MD || MD->getNumOperands() == 0)
145 continue;
146 for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
147 MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
148 }
149 }
150
151 // Update required capabilities for this memory model, addressing model and
152 // source language.
153 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
154 MAI.Mem, *ST);
155 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
156 MAI.SrcLang, *ST);
157 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
158 MAI.Addr, *ST);
159
160 if (ST->isOpenCLEnv()) {
161 // TODO: check if it's required by default.
162 MAI.ExtInstSetMap[static_cast<unsigned>(
163 SPIRV::InstructionSet::OpenCL_std)] =
164 Register::index2VirtReg(MAI.getNextID());
165 }
166 }
167
168 // Collect MI which defines the register in the given machine function.
collectDefInstr(Register Reg,const MachineFunction * MF,SPIRV::ModuleAnalysisInfo * MAI,SPIRV::ModuleSectionType MSType,bool DoInsert=true)169 static void collectDefInstr(Register Reg, const MachineFunction *MF,
170 SPIRV::ModuleAnalysisInfo *MAI,
171 SPIRV::ModuleSectionType MSType,
172 bool DoInsert = true) {
173 assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
174 MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
175 assert(MI && "There should be an instruction that defines the register");
176 MAI->setSkipEmission(MI);
177 if (DoInsert)
178 MAI->MS[MSType].push_back(MI);
179 }
180
collectGlobalEntities(const std::vector<SPIRV::DTSortableEntry * > & DepsGraph,SPIRV::ModuleSectionType MSType,std::function<bool (const SPIRV::DTSortableEntry *)> Pred,bool UsePreOrder=false)181 void SPIRVModuleAnalysis::collectGlobalEntities(
182 const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
183 SPIRV::ModuleSectionType MSType,
184 std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
185 bool UsePreOrder = false) {
186 DenseSet<const SPIRV::DTSortableEntry *> Visited;
187 for (const auto *E : DepsGraph) {
188 std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
189 // NOTE: here we prefer recursive approach over iterative because
190 // we don't expect depchains long enough to cause SO.
191 RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
192 &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
193 if (Visited.count(E) || !Pred(E))
194 return;
195 Visited.insert(E);
196
197 // Traversing deps graph in post-order allows us to get rid of
198 // register aliases preprocessing.
199 // But pre-order is required for correct processing of function
200 // declaration and arguments processing.
201 if (!UsePreOrder)
202 for (auto *S : E->getDeps())
203 RecHoistUtil(S);
204
205 Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
206 bool IsFirst = true;
207 for (auto &U : *E) {
208 const MachineFunction *MF = U.first;
209 Register Reg = U.second;
210 MAI.setRegisterAlias(MF, Reg, GlobalReg);
211 if (!MF->getRegInfo().getUniqueVRegDef(Reg))
212 continue;
213 collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
214 IsFirst = false;
215 if (E->getIsGV())
216 MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
217 }
218
219 if (UsePreOrder)
220 for (auto *S : E->getDeps())
221 RecHoistUtil(S);
222 };
223 RecHoistUtil(E);
224 }
225 }
226
227 // The function initializes global register alias table for types, consts,
228 // global vars and func decls and collects these instruction for output
229 // at module level. Also it collects explicit OpExtension/OpCapability
230 // instructions.
processDefInstrs(const Module & M)231 void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
232 std::vector<SPIRV::DTSortableEntry *> DepsGraph;
233
234 GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);
235
236 collectGlobalEntities(
237 DepsGraph, SPIRV::MB_TypeConstVars,
238 [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });
239
240 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
241 MachineFunction *MF = MMI->getMachineFunction(*F);
242 if (!MF)
243 continue;
244 // Iterate through and collect OpExtension/OpCapability instructions.
245 for (MachineBasicBlock &MBB : *MF) {
246 for (MachineInstr &MI : MBB) {
247 if (MI.getOpcode() == SPIRV::OpExtension) {
248 // Here, OpExtension just has a single enum operand, not a string.
249 auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
250 MAI.Reqs.addExtension(Ext);
251 MAI.setSkipEmission(&MI);
252 } else if (MI.getOpcode() == SPIRV::OpCapability) {
253 auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
254 MAI.Reqs.addCapability(Cap);
255 MAI.setSkipEmission(&MI);
256 }
257 }
258 }
259 }
260
261 collectGlobalEntities(
262 DepsGraph, SPIRV::MB_ExtFuncDecls,
263 [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
264 }
265
266 // Look for IDs declared with Import linkage, and map the corresponding function
267 // to the register defining that variable (which will usually be the result of
268 // an OpFunction). This lets us call externally imported functions using
269 // the correct ID registers.
collectFuncNames(MachineInstr & MI,const Function * F)270 void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
271 const Function *F) {
272 if (MI.getOpcode() == SPIRV::OpDecorate) {
273 // If it's got Import linkage.
274 auto Dec = MI.getOperand(1).getImm();
275 if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
276 auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
277 if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
278 // Map imported function name to function ID register.
279 const Function *ImportedFunc =
280 F->getParent()->getFunction(getStringImm(MI, 2));
281 Register Target = MI.getOperand(0).getReg();
282 MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
283 }
284 }
285 } else if (MI.getOpcode() == SPIRV::OpFunction) {
286 // Record all internal OpFunction declarations.
287 Register Reg = MI.defs().begin()->getReg();
288 Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
289 assert(GlobalReg.isValid());
290 MAI.FuncMap[F] = GlobalReg;
291 }
292 }
293
294 using InstrSignature = SmallVector<size_t>;
295 using InstrTraces = std::set<InstrSignature>;
296
297 // Returns a representation of an instruction as a vector of MachineOperand
298 // hash values, see llvm::hash_value(const MachineOperand &MO) for details.
299 // This creates a signature of the instruction with the same content
300 // that MachineOperand::isIdenticalTo uses for comparison.
instrToSignature(MachineInstr & MI,SPIRV::ModuleAnalysisInfo & MAI)301 static InstrSignature instrToSignature(MachineInstr &MI,
302 SPIRV::ModuleAnalysisInfo &MAI) {
303 InstrSignature Signature;
304 for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
305 const MachineOperand &MO = MI.getOperand(i);
306 size_t h;
307 if (MO.isReg()) {
308 Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
309 // mimic llvm::hash_value(const MachineOperand &MO)
310 h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
311 MO.isDef());
312 } else {
313 h = hash_value(MO);
314 }
315 Signature.push_back(h);
316 }
317 return Signature;
318 }
319
320 // Collect the given instruction in the specified MS. We assume global register
321 // numbering has already occurred by this point. We can directly compare reg
322 // arguments when detecting duplicates.
collectOtherInstr(MachineInstr & MI,SPIRV::ModuleAnalysisInfo & MAI,SPIRV::ModuleSectionType MSType,InstrTraces & IS,bool Append=true)323 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
324 SPIRV::ModuleSectionType MSType, InstrTraces &IS,
325 bool Append = true) {
326 MAI.setSkipEmission(&MI);
327 InstrSignature MISign = instrToSignature(MI, MAI);
328 auto FoundMI = IS.insert(MISign);
329 if (!FoundMI.second)
330 return; // insert failed, so we found a duplicate; don't add it to MAI.MS
331 // No duplicates, so add it.
332 if (Append)
333 MAI.MS[MSType].push_back(&MI);
334 else
335 MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
336 }
337
338 // Some global instructions make reference to function-local ID regs, so cannot
339 // be correctly collected until these registers are globally numbered.
processOtherInstrs(const Module & M)340 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
341 InstrTraces IS;
342 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
343 if ((*F).isDeclaration())
344 continue;
345 MachineFunction *MF = MMI->getMachineFunction(*F);
346 assert(MF);
347 for (MachineBasicBlock &MBB : *MF)
348 for (MachineInstr &MI : MBB) {
349 if (MAI.getSkipEmission(&MI))
350 continue;
351 const unsigned OpCode = MI.getOpcode();
352 if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
353 collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
354 } else if (OpCode == SPIRV::OpEntryPoint) {
355 collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
356 } else if (TII->isDecorationInstr(MI)) {
357 collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
358 collectFuncNames(MI, &*F);
359 } else if (TII->isConstantInstr(MI)) {
360 // Now OpSpecConstant*s are not in DT,
361 // but they need to be collected anyway.
362 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
363 } else if (OpCode == SPIRV::OpFunction) {
364 collectFuncNames(MI, &*F);
365 } else if (OpCode == SPIRV::OpTypeForwardPointer) {
366 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
367 }
368 }
369 }
370 }
371
372 // Number registers in all functions globally from 0 onwards and store
373 // the result in global register alias table. Some registers are already
374 // numbered in collectGlobalEntities.
numberRegistersGlobally(const Module & M)375 void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
376 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
377 if ((*F).isDeclaration())
378 continue;
379 MachineFunction *MF = MMI->getMachineFunction(*F);
380 assert(MF);
381 for (MachineBasicBlock &MBB : *MF) {
382 for (MachineInstr &MI : MBB) {
383 for (MachineOperand &Op : MI.operands()) {
384 if (!Op.isReg())
385 continue;
386 Register Reg = Op.getReg();
387 if (MAI.hasRegisterAlias(MF, Reg))
388 continue;
389 Register NewReg = Register::index2VirtReg(MAI.getNextID());
390 MAI.setRegisterAlias(MF, Reg, NewReg);
391 }
392 if (MI.getOpcode() != SPIRV::OpExtInst)
393 continue;
394 auto Set = MI.getOperand(2).getImm();
395 if (!MAI.ExtInstSetMap.contains(Set))
396 MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
397 }
398 }
399 }
400 }
401
402 // RequirementHandler implementations.
getAndAddRequirements(SPIRV::OperandCategory::OperandCategory Category,uint32_t i,const SPIRVSubtarget & ST)403 void SPIRV::RequirementHandler::getAndAddRequirements(
404 SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
405 const SPIRVSubtarget &ST) {
406 addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
407 }
408
pruneCapabilities(const CapabilityList & ToPrune)409 void SPIRV::RequirementHandler::pruneCapabilities(
410 const CapabilityList &ToPrune) {
411 for (const auto &Cap : ToPrune) {
412 AllCaps.insert(Cap);
413 auto FoundIndex = llvm::find(MinimalCaps, Cap);
414 if (FoundIndex != MinimalCaps.end())
415 MinimalCaps.erase(FoundIndex);
416 CapabilityList ImplicitDecls =
417 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
418 pruneCapabilities(ImplicitDecls);
419 }
420 }
421
addCapabilities(const CapabilityList & ToAdd)422 void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
423 for (const auto &Cap : ToAdd) {
424 bool IsNewlyInserted = AllCaps.insert(Cap).second;
425 if (!IsNewlyInserted) // Don't re-add if it's already been declared.
426 continue;
427 CapabilityList ImplicitDecls =
428 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
429 pruneCapabilities(ImplicitDecls);
430 MinimalCaps.push_back(Cap);
431 }
432 }
433
addRequirements(const SPIRV::Requirements & Req)434 void SPIRV::RequirementHandler::addRequirements(
435 const SPIRV::Requirements &Req) {
436 if (!Req.IsSatisfiable)
437 report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
438
439 if (Req.Cap.has_value())
440 addCapabilities({Req.Cap.value()});
441
442 addExtensions(Req.Exts);
443
444 if (Req.MinVer) {
445 if (MaxVersion && Req.MinVer > MaxVersion) {
446 LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
447 << " and <= " << MaxVersion << "\n");
448 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
449 }
450
451 if (MinVersion == 0 || Req.MinVer > MinVersion)
452 MinVersion = Req.MinVer;
453 }
454
455 if (Req.MaxVer) {
456 if (MinVersion && Req.MaxVer < MinVersion) {
457 LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
458 << " and >= " << MinVersion << "\n");
459 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
460 }
461
462 if (MaxVersion == 0 || Req.MaxVer < MaxVersion)
463 MaxVersion = Req.MaxVer;
464 }
465 }
466
checkSatisfiable(const SPIRVSubtarget & ST) const467 void SPIRV::RequirementHandler::checkSatisfiable(
468 const SPIRVSubtarget &ST) const {
469 // Report as many errors as possible before aborting the compilation.
470 bool IsSatisfiable = true;
471 auto TargetVer = ST.getSPIRVVersion();
472
473 if (MaxVersion && TargetVer && MaxVersion < TargetVer) {
474 LLVM_DEBUG(
475 dbgs() << "Target SPIR-V version too high for required features\n"
476 << "Required max version: " << MaxVersion << " target version "
477 << TargetVer << "\n");
478 IsSatisfiable = false;
479 }
480
481 if (MinVersion && TargetVer && MinVersion > TargetVer) {
482 LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
483 << "Required min version: " << MinVersion
484 << " target version " << TargetVer << "\n");
485 IsSatisfiable = false;
486 }
487
488 if (MinVersion && MaxVersion && MinVersion > MaxVersion) {
489 LLVM_DEBUG(
490 dbgs()
491 << "Version is too low for some features and too high for others.\n"
492 << "Required SPIR-V min version: " << MinVersion
493 << " required SPIR-V max version " << MaxVersion << "\n");
494 IsSatisfiable = false;
495 }
496
497 for (auto Cap : MinimalCaps) {
498 if (AvailableCaps.contains(Cap))
499 continue;
500 LLVM_DEBUG(dbgs() << "Capability not supported: "
501 << getSymbolicOperandMnemonic(
502 OperandCategory::CapabilityOperand, Cap)
503 << "\n");
504 IsSatisfiable = false;
505 }
506
507 for (auto Ext : AllExtensions) {
508 if (ST.canUseExtension(Ext))
509 continue;
510 LLVM_DEBUG(dbgs() << "Extension not supported: "
511 << getSymbolicOperandMnemonic(
512 OperandCategory::ExtensionOperand, Ext)
513 << "\n");
514 IsSatisfiable = false;
515 }
516
517 if (!IsSatisfiable)
518 report_fatal_error("Unable to meet SPIR-V requirements for this target.");
519 }
520
521 // Add the given capabilities and all their implicitly defined capabilities too.
addAvailableCaps(const CapabilityList & ToAdd)522 void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
523 for (const auto Cap : ToAdd)
524 if (AvailableCaps.insert(Cap).second)
525 addAvailableCaps(getSymbolicOperandCapabilities(
526 SPIRV::OperandCategory::CapabilityOperand, Cap));
527 }
528
removeCapabilityIf(const Capability::Capability ToRemove,const Capability::Capability IfPresent)529 void SPIRV::RequirementHandler::removeCapabilityIf(
530 const Capability::Capability ToRemove,
531 const Capability::Capability IfPresent) {
532 if (AllCaps.contains(IfPresent))
533 AllCaps.erase(ToRemove);
534 }
535
536 namespace llvm {
537 namespace SPIRV {
initAvailableCapabilities(const SPIRVSubtarget & ST)538 void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
539 if (ST.isOpenCLEnv()) {
540 initAvailableCapabilitiesForOpenCL(ST);
541 return;
542 }
543
544 if (ST.isVulkanEnv()) {
545 initAvailableCapabilitiesForVulkan(ST);
546 return;
547 }
548
549 report_fatal_error("Unimplemented environment for SPIR-V generation.");
550 }
551
initAvailableCapabilitiesForOpenCL(const SPIRVSubtarget & ST)552 void RequirementHandler::initAvailableCapabilitiesForOpenCL(
553 const SPIRVSubtarget &ST) {
554 // Add the min requirements for different OpenCL and SPIR-V versions.
555 addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
556 Capability::Int16, Capability::Int8, Capability::Kernel,
557 Capability::Linkage, Capability::Vector16,
558 Capability::Groups, Capability::GenericPointer,
559 Capability::Shader});
560 if (ST.hasOpenCLFullProfile())
561 addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
562 if (ST.hasOpenCLImageSupport()) {
563 addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
564 Capability::Image1D, Capability::SampledBuffer,
565 Capability::ImageBuffer});
566 if (ST.isAtLeastOpenCLVer(20))
567 addAvailableCaps({Capability::ImageReadWrite});
568 }
569 if (ST.isAtLeastSPIRVVer(11) && ST.isAtLeastOpenCLVer(22))
570 addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
571 if (ST.isAtLeastSPIRVVer(13))
572 addAvailableCaps({Capability::GroupNonUniform,
573 Capability::GroupNonUniformVote,
574 Capability::GroupNonUniformArithmetic,
575 Capability::GroupNonUniformBallot,
576 Capability::GroupNonUniformClustered,
577 Capability::GroupNonUniformShuffle,
578 Capability::GroupNonUniformShuffleRelative});
579 if (ST.isAtLeastSPIRVVer(14))
580 addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
581 Capability::SignedZeroInfNanPreserve,
582 Capability::RoundingModeRTE,
583 Capability::RoundingModeRTZ});
584 // TODO: verify if this needs some checks.
585 addAvailableCaps({Capability::Float16, Capability::Float64});
586
587 // Add capabilities enabled by extensions.
588 for (auto Extension : ST.getAllAvailableExtensions()) {
589 CapabilityList EnabledCapabilities =
590 getCapabilitiesEnabledByExtension(Extension);
591 addAvailableCaps(EnabledCapabilities);
592 }
593
594 // TODO: add OpenCL extensions.
595 }
596
initAvailableCapabilitiesForVulkan(const SPIRVSubtarget & ST)597 void RequirementHandler::initAvailableCapabilitiesForVulkan(
598 const SPIRVSubtarget &ST) {
599 addAvailableCaps({Capability::Shader, Capability::Linkage});
600
601 // Provided by all supported Vulkan versions.
602 addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
603 Capability::Float64});
604 }
605
606 } // namespace SPIRV
607 } // namespace llvm
608
609 // Add the required capabilities from a decoration instruction (including
610 // BuiltIns).
addOpDecorateReqs(const MachineInstr & MI,unsigned DecIndex,SPIRV::RequirementHandler & Reqs,const SPIRVSubtarget & ST)611 static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
612 SPIRV::RequirementHandler &Reqs,
613 const SPIRVSubtarget &ST) {
614 int64_t DecOp = MI.getOperand(DecIndex).getImm();
615 auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
616 Reqs.addRequirements(getSymbolicOperandRequirements(
617 SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
618
619 if (Dec == SPIRV::Decoration::BuiltIn) {
620 int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
621 auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
622 Reqs.addRequirements(getSymbolicOperandRequirements(
623 SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
624 }
625 }
626
627 // Add requirements for image handling.
addOpTypeImageReqs(const MachineInstr & MI,SPIRV::RequirementHandler & Reqs,const SPIRVSubtarget & ST)628 static void addOpTypeImageReqs(const MachineInstr &MI,
629 SPIRV::RequirementHandler &Reqs,
630 const SPIRVSubtarget &ST) {
631 assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
632 // The operand indices used here are based on the OpTypeImage layout, which
633 // the MachineInstr follows as well.
634 int64_t ImgFormatOp = MI.getOperand(7).getImm();
635 auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
636 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
637 ImgFormat, ST);
638
639 bool IsArrayed = MI.getOperand(4).getImm() == 1;
640 bool IsMultisampled = MI.getOperand(5).getImm() == 1;
641 bool NoSampler = MI.getOperand(6).getImm() == 2;
642 // Add dimension requirements.
643 assert(MI.getOperand(2).isImm());
644 switch (MI.getOperand(2).getImm()) {
645 case SPIRV::Dim::DIM_1D:
646 Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
647 : SPIRV::Capability::Sampled1D);
648 break;
649 case SPIRV::Dim::DIM_2D:
650 if (IsMultisampled && NoSampler)
651 Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
652 break;
653 case SPIRV::Dim::DIM_Cube:
654 Reqs.addRequirements(SPIRV::Capability::Shader);
655 if (IsArrayed)
656 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
657 : SPIRV::Capability::SampledCubeArray);
658 break;
659 case SPIRV::Dim::DIM_Rect:
660 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
661 : SPIRV::Capability::SampledRect);
662 break;
663 case SPIRV::Dim::DIM_Buffer:
664 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
665 : SPIRV::Capability::SampledBuffer);
666 break;
667 case SPIRV::Dim::DIM_SubpassData:
668 Reqs.addRequirements(SPIRV::Capability::InputAttachment);
669 break;
670 }
671
672 // Has optional access qualifier.
673 // TODO: check if it's OpenCL's kernel.
674 if (MI.getNumOperands() > 8 &&
675 MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
676 Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
677 else
678 Reqs.addRequirements(SPIRV::Capability::ImageBasic);
679 }
680
addInstrRequirements(const MachineInstr & MI,SPIRV::RequirementHandler & Reqs,const SPIRVSubtarget & ST)681 void addInstrRequirements(const MachineInstr &MI,
682 SPIRV::RequirementHandler &Reqs,
683 const SPIRVSubtarget &ST) {
684 switch (MI.getOpcode()) {
685 case SPIRV::OpMemoryModel: {
686 int64_t Addr = MI.getOperand(0).getImm();
687 Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
688 Addr, ST);
689 int64_t Mem = MI.getOperand(1).getImm();
690 Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
691 ST);
692 break;
693 }
694 case SPIRV::OpEntryPoint: {
695 int64_t Exe = MI.getOperand(0).getImm();
696 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
697 Exe, ST);
698 break;
699 }
700 case SPIRV::OpExecutionMode:
701 case SPIRV::OpExecutionModeId: {
702 int64_t Exe = MI.getOperand(1).getImm();
703 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
704 Exe, ST);
705 break;
706 }
707 case SPIRV::OpTypeMatrix:
708 Reqs.addCapability(SPIRV::Capability::Matrix);
709 break;
710 case SPIRV::OpTypeInt: {
711 unsigned BitWidth = MI.getOperand(1).getImm();
712 if (BitWidth == 64)
713 Reqs.addCapability(SPIRV::Capability::Int64);
714 else if (BitWidth == 16)
715 Reqs.addCapability(SPIRV::Capability::Int16);
716 else if (BitWidth == 8)
717 Reqs.addCapability(SPIRV::Capability::Int8);
718 break;
719 }
720 case SPIRV::OpTypeFloat: {
721 unsigned BitWidth = MI.getOperand(1).getImm();
722 if (BitWidth == 64)
723 Reqs.addCapability(SPIRV::Capability::Float64);
724 else if (BitWidth == 16)
725 Reqs.addCapability(SPIRV::Capability::Float16);
726 break;
727 }
728 case SPIRV::OpTypeVector: {
729 unsigned NumComponents = MI.getOperand(2).getImm();
730 if (NumComponents == 8 || NumComponents == 16)
731 Reqs.addCapability(SPIRV::Capability::Vector16);
732 break;
733 }
734 case SPIRV::OpTypePointer: {
735 auto SC = MI.getOperand(1).getImm();
736 Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
737 ST);
738 // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
739 // capability.
740 if (!ST.isOpenCLEnv())
741 break;
742 assert(MI.getOperand(2).isReg());
743 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
744 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
745 if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
746 TypeDef->getOperand(1).getImm() == 16)
747 Reqs.addCapability(SPIRV::Capability::Float16Buffer);
748 break;
749 }
750 case SPIRV::OpBitReverse:
751 case SPIRV::OpBitFieldInsert:
752 case SPIRV::OpBitFieldSExtract:
753 case SPIRV::OpBitFieldUExtract:
754 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
755 Reqs.addCapability(SPIRV::Capability::Shader);
756 break;
757 }
758 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
759 Reqs.addCapability(SPIRV::Capability::BitInstructions);
760 break;
761 case SPIRV::OpTypeRuntimeArray:
762 Reqs.addCapability(SPIRV::Capability::Shader);
763 break;
764 case SPIRV::OpTypeOpaque:
765 case SPIRV::OpTypeEvent:
766 Reqs.addCapability(SPIRV::Capability::Kernel);
767 break;
768 case SPIRV::OpTypePipe:
769 case SPIRV::OpTypeReserveId:
770 Reqs.addCapability(SPIRV::Capability::Pipes);
771 break;
772 case SPIRV::OpTypeDeviceEvent:
773 case SPIRV::OpTypeQueue:
774 case SPIRV::OpBuildNDRange:
775 Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
776 break;
777 case SPIRV::OpDecorate:
778 case SPIRV::OpDecorateId:
779 case SPIRV::OpDecorateString:
780 addOpDecorateReqs(MI, 1, Reqs, ST);
781 break;
782 case SPIRV::OpMemberDecorate:
783 case SPIRV::OpMemberDecorateString:
784 addOpDecorateReqs(MI, 2, Reqs, ST);
785 break;
786 case SPIRV::OpInBoundsPtrAccessChain:
787 Reqs.addCapability(SPIRV::Capability::Addresses);
788 break;
789 case SPIRV::OpConstantSampler:
790 Reqs.addCapability(SPIRV::Capability::LiteralSampler);
791 break;
792 case SPIRV::OpTypeImage:
793 addOpTypeImageReqs(MI, Reqs, ST);
794 break;
795 case SPIRV::OpTypeSampler:
796 Reqs.addCapability(SPIRV::Capability::ImageBasic);
797 break;
798 case SPIRV::OpTypeForwardPointer:
799 // TODO: check if it's OpenCL's kernel.
800 Reqs.addCapability(SPIRV::Capability::Addresses);
801 break;
802 case SPIRV::OpAtomicFlagTestAndSet:
803 case SPIRV::OpAtomicLoad:
804 case SPIRV::OpAtomicStore:
805 case SPIRV::OpAtomicExchange:
806 case SPIRV::OpAtomicCompareExchange:
807 case SPIRV::OpAtomicIIncrement:
808 case SPIRV::OpAtomicIDecrement:
809 case SPIRV::OpAtomicIAdd:
810 case SPIRV::OpAtomicISub:
811 case SPIRV::OpAtomicUMin:
812 case SPIRV::OpAtomicUMax:
813 case SPIRV::OpAtomicSMin:
814 case SPIRV::OpAtomicSMax:
815 case SPIRV::OpAtomicAnd:
816 case SPIRV::OpAtomicOr:
817 case SPIRV::OpAtomicXor: {
818 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
819 const MachineInstr *InstrPtr = &MI;
820 if (MI.getOpcode() == SPIRV::OpAtomicStore) {
821 assert(MI.getOperand(3).isReg());
822 InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
823 assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
824 }
825 assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
826 Register TypeReg = InstrPtr->getOperand(1).getReg();
827 SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
828 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
829 unsigned BitWidth = TypeDef->getOperand(1).getImm();
830 if (BitWidth == 64)
831 Reqs.addCapability(SPIRV::Capability::Int64Atomics);
832 }
833 break;
834 }
835 case SPIRV::OpGroupNonUniformIAdd:
836 case SPIRV::OpGroupNonUniformFAdd:
837 case SPIRV::OpGroupNonUniformIMul:
838 case SPIRV::OpGroupNonUniformFMul:
839 case SPIRV::OpGroupNonUniformSMin:
840 case SPIRV::OpGroupNonUniformUMin:
841 case SPIRV::OpGroupNonUniformFMin:
842 case SPIRV::OpGroupNonUniformSMax:
843 case SPIRV::OpGroupNonUniformUMax:
844 case SPIRV::OpGroupNonUniformFMax:
845 case SPIRV::OpGroupNonUniformBitwiseAnd:
846 case SPIRV::OpGroupNonUniformBitwiseOr:
847 case SPIRV::OpGroupNonUniformBitwiseXor:
848 case SPIRV::OpGroupNonUniformLogicalAnd:
849 case SPIRV::OpGroupNonUniformLogicalOr:
850 case SPIRV::OpGroupNonUniformLogicalXor: {
851 assert(MI.getOperand(3).isImm());
852 int64_t GroupOp = MI.getOperand(3).getImm();
853 switch (GroupOp) {
854 case SPIRV::GroupOperation::Reduce:
855 case SPIRV::GroupOperation::InclusiveScan:
856 case SPIRV::GroupOperation::ExclusiveScan:
857 Reqs.addCapability(SPIRV::Capability::Kernel);
858 Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
859 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
860 break;
861 case SPIRV::GroupOperation::ClusteredReduce:
862 Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
863 break;
864 case SPIRV::GroupOperation::PartitionedReduceNV:
865 case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
866 case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
867 Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
868 break;
869 }
870 break;
871 }
872 case SPIRV::OpGroupNonUniformShuffle:
873 case SPIRV::OpGroupNonUniformShuffleXor:
874 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
875 break;
876 case SPIRV::OpGroupNonUniformShuffleUp:
877 case SPIRV::OpGroupNonUniformShuffleDown:
878 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
879 break;
880 case SPIRV::OpGroupAll:
881 case SPIRV::OpGroupAny:
882 case SPIRV::OpGroupBroadcast:
883 case SPIRV::OpGroupIAdd:
884 case SPIRV::OpGroupFAdd:
885 case SPIRV::OpGroupFMin:
886 case SPIRV::OpGroupUMin:
887 case SPIRV::OpGroupSMin:
888 case SPIRV::OpGroupFMax:
889 case SPIRV::OpGroupUMax:
890 case SPIRV::OpGroupSMax:
891 Reqs.addCapability(SPIRV::Capability::Groups);
892 break;
893 case SPIRV::OpGroupNonUniformElect:
894 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
895 break;
896 case SPIRV::OpGroupNonUniformAll:
897 case SPIRV::OpGroupNonUniformAny:
898 case SPIRV::OpGroupNonUniformAllEqual:
899 Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
900 break;
901 case SPIRV::OpGroupNonUniformBroadcast:
902 case SPIRV::OpGroupNonUniformBroadcastFirst:
903 case SPIRV::OpGroupNonUniformBallot:
904 case SPIRV::OpGroupNonUniformInverseBallot:
905 case SPIRV::OpGroupNonUniformBallotBitExtract:
906 case SPIRV::OpGroupNonUniformBallotBitCount:
907 case SPIRV::OpGroupNonUniformBallotFindLSB:
908 case SPIRV::OpGroupNonUniformBallotFindMSB:
909 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
910 break;
911 case SPIRV::OpAssumeTrueKHR:
912 case SPIRV::OpExpectKHR:
913 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
914 Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
915 Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
916 }
917 break;
918 default:
919 break;
920 }
921
922 // If we require capability Shader, then we can remove the requirement for
923 // the BitInstructions capability, since Shader is a superset capability
924 // of BitInstructions.
925 Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
926 SPIRV::Capability::Shader);
927 }
928
collectReqs(const Module & M,SPIRV::ModuleAnalysisInfo & MAI,MachineModuleInfo * MMI,const SPIRVSubtarget & ST)929 static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
930 MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
931 // Collect requirements for existing instructions.
932 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
933 MachineFunction *MF = MMI->getMachineFunction(*F);
934 if (!MF)
935 continue;
936 for (const MachineBasicBlock &MBB : *MF)
937 for (const MachineInstr &MI : MBB)
938 addInstrRequirements(MI, MAI.Reqs, ST);
939 }
940 // Collect requirements for OpExecutionMode instructions.
941 auto Node = M.getNamedMetadata("spirv.ExecutionMode");
942 if (Node) {
943 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
944 MDNode *MDN = cast<MDNode>(Node->getOperand(i));
945 const MDOperand &MDOp = MDN->getOperand(1);
946 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
947 Constant *C = CMeta->getValue();
948 if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
949 auto EM = Const->getZExtValue();
950 MAI.Reqs.getAndAddRequirements(
951 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
952 }
953 }
954 }
955 }
956 for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
957 const Function &F = *FI;
958 if (F.isDeclaration())
959 continue;
960 if (F.getMetadata("reqd_work_group_size"))
961 MAI.Reqs.getAndAddRequirements(
962 SPIRV::OperandCategory::ExecutionModeOperand,
963 SPIRV::ExecutionMode::LocalSize, ST);
964 if (F.getFnAttribute("hlsl.numthreads").isValid()) {
965 MAI.Reqs.getAndAddRequirements(
966 SPIRV::OperandCategory::ExecutionModeOperand,
967 SPIRV::ExecutionMode::LocalSize, ST);
968 }
969 if (F.getMetadata("work_group_size_hint"))
970 MAI.Reqs.getAndAddRequirements(
971 SPIRV::OperandCategory::ExecutionModeOperand,
972 SPIRV::ExecutionMode::LocalSizeHint, ST);
973 if (F.getMetadata("intel_reqd_sub_group_size"))
974 MAI.Reqs.getAndAddRequirements(
975 SPIRV::OperandCategory::ExecutionModeOperand,
976 SPIRV::ExecutionMode::SubgroupSize, ST);
977 if (F.getMetadata("vec_type_hint"))
978 MAI.Reqs.getAndAddRequirements(
979 SPIRV::OperandCategory::ExecutionModeOperand,
980 SPIRV::ExecutionMode::VecTypeHint, ST);
981
982 if (F.hasOptNone() &&
983 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
984 // Output OpCapability OptNoneINTEL.
985 MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
986 MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
987 }
988 }
989 }
990
getFastMathFlags(const MachineInstr & I)991 static unsigned getFastMathFlags(const MachineInstr &I) {
992 unsigned Flags = SPIRV::FPFastMathMode::None;
993 if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
994 Flags |= SPIRV::FPFastMathMode::NotNaN;
995 if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
996 Flags |= SPIRV::FPFastMathMode::NotInf;
997 if (I.getFlag(MachineInstr::MIFlag::FmNsz))
998 Flags |= SPIRV::FPFastMathMode::NSZ;
999 if (I.getFlag(MachineInstr::MIFlag::FmArcp))
1000 Flags |= SPIRV::FPFastMathMode::AllowRecip;
1001 if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
1002 Flags |= SPIRV::FPFastMathMode::Fast;
1003 return Flags;
1004 }
1005
handleMIFlagDecoration(MachineInstr & I,const SPIRVSubtarget & ST,const SPIRVInstrInfo & TII,SPIRV::RequirementHandler & Reqs)1006 static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
1007 const SPIRVInstrInfo &TII,
1008 SPIRV::RequirementHandler &Reqs) {
1009 if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
1010 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1011 SPIRV::Decoration::NoSignedWrap, ST, Reqs)
1012 .IsSatisfiable) {
1013 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1014 SPIRV::Decoration::NoSignedWrap, {});
1015 }
1016 if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
1017 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1018 SPIRV::Decoration::NoUnsignedWrap, ST,
1019 Reqs)
1020 .IsSatisfiable) {
1021 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1022 SPIRV::Decoration::NoUnsignedWrap, {});
1023 }
1024 if (!TII.canUseFastMathFlags(I))
1025 return;
1026 unsigned FMFlags = getFastMathFlags(I);
1027 if (FMFlags == SPIRV::FPFastMathMode::None)
1028 return;
1029 Register DstReg = I.getOperand(0).getReg();
1030 buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
1031 }
1032
1033 // Walk all functions and add decorations related to MI flags.
addDecorations(const Module & M,const SPIRVInstrInfo & TII,MachineModuleInfo * MMI,const SPIRVSubtarget & ST,SPIRV::ModuleAnalysisInfo & MAI)1034 static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
1035 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
1036 SPIRV::ModuleAnalysisInfo &MAI) {
1037 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1038 MachineFunction *MF = MMI->getMachineFunction(*F);
1039 if (!MF)
1040 continue;
1041 for (auto &MBB : *MF)
1042 for (auto &MI : MBB)
1043 handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
1044 }
1045 }
1046
1047 struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
1048
getAnalysisUsage(AnalysisUsage & AU) const1049 void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
1050 AU.addRequired<TargetPassConfig>();
1051 AU.addRequired<MachineModuleInfoWrapperPass>();
1052 }
1053
runOnModule(Module & M)1054 bool SPIRVModuleAnalysis::runOnModule(Module &M) {
1055 SPIRVTargetMachine &TM =
1056 getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
1057 ST = TM.getSubtargetImpl();
1058 GR = ST->getSPIRVGlobalRegistry();
1059 TII = ST->getInstrInfo();
1060
1061 MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
1062
1063 setBaseInfo(M);
1064
1065 addDecorations(M, *TII, MMI, *ST, MAI);
1066
1067 collectReqs(M, MAI, MMI, *ST);
1068
1069 // Process type/const/global var/func decl instructions, number their
1070 // destination registers from 0 to N, collect Extensions and Capabilities.
1071 processDefInstrs(M);
1072
1073 // Number rest of registers from N+1 onwards.
1074 numberRegistersGlobally(M);
1075
1076 // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
1077 processOtherInstrs(M);
1078
1079 // If there are no entry points, we need the Linkage capability.
1080 if (MAI.MS[SPIRV::MB_EntryPoints].empty())
1081 MAI.Reqs.addCapability(SPIRV::Capability::Linkage);
1082
1083 return false;
1084 }
1085