1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
19 #include "NVPTX.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Triple.h"
37 #include "llvm/ADT/Twine.h"
38 #include "llvm/Analysis/ConstantFolding.h"
39 #include "llvm/CodeGen/Analysis.h"
40 #include "llvm/CodeGen/MachineBasicBlock.h"
41 #include "llvm/CodeGen/MachineFrameInfo.h"
42 #include "llvm/CodeGen/MachineFunction.h"
43 #include "llvm/CodeGen/MachineInstr.h"
44 #include "llvm/CodeGen/MachineLoopInfo.h"
45 #include "llvm/CodeGen/MachineModuleInfo.h"
46 #include "llvm/CodeGen/MachineOperand.h"
47 #include "llvm/CodeGen/MachineRegisterInfo.h"
48 #include "llvm/CodeGen/TargetRegisterInfo.h"
49 #include "llvm/CodeGen/ValueTypes.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/BasicBlock.h"
52 #include "llvm/IR/Constant.h"
53 #include "llvm/IR/Constants.h"
54 #include "llvm/IR/DataLayout.h"
55 #include "llvm/IR/DebugInfo.h"
56 #include "llvm/IR/DebugInfoMetadata.h"
57 #include "llvm/IR/DebugLoc.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/GlobalValue.h"
61 #include "llvm/IR/GlobalVariable.h"
62 #include "llvm/IR/Instruction.h"
63 #include "llvm/IR/LLVMContext.h"
64 #include "llvm/IR/Module.h"
65 #include "llvm/IR/Operator.h"
66 #include "llvm/IR/Type.h"
67 #include "llvm/IR/User.h"
68 #include "llvm/MC/MCExpr.h"
69 #include "llvm/MC/MCInst.h"
70 #include "llvm/MC/MCInstrDesc.h"
71 #include "llvm/MC/MCStreamer.h"
72 #include "llvm/MC/MCSymbol.h"
73 #include "llvm/MC/TargetRegistry.h"
74 #include "llvm/Support/Casting.h"
75 #include "llvm/Support/CommandLine.h"
76 #include "llvm/Support/Endian.h"
77 #include "llvm/Support/ErrorHandling.h"
78 #include "llvm/Support/MachineValueType.h"
79 #include "llvm/Support/NativeFormatting.h"
80 #include "llvm/Support/Path.h"
81 #include "llvm/Support/raw_ostream.h"
82 #include "llvm/Target/TargetLoweringObjectFile.h"
83 #include "llvm/Target/TargetMachine.h"
84 #include "llvm/Transforms/Utils/UnrollLoop.h"
85 #include <cassert>
86 #include <cstdint>
87 #include <cstring>
88 #include <new>
89 #include <string>
90 #include <utility>
91 #include <vector>
92
93 using namespace llvm;
94
95 #define DEPOTNAME "__local_depot"
96
97 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
98 /// depends.
99 static void
DiscoverDependentGlobals(const Value * V,DenseSet<const GlobalVariable * > & Globals)100 DiscoverDependentGlobals(const Value *V,
101 DenseSet<const GlobalVariable *> &Globals) {
102 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
103 Globals.insert(GV);
104 else {
105 if (const User *U = dyn_cast<User>(V)) {
106 for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
107 DiscoverDependentGlobals(U->getOperand(i), Globals);
108 }
109 }
110 }
111 }
112
113 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
114 /// instances to be emitted, but only after any dependents have been added
115 /// first.s
116 static void
VisitGlobalVariableForEmission(const GlobalVariable * GV,SmallVectorImpl<const GlobalVariable * > & Order,DenseSet<const GlobalVariable * > & Visited,DenseSet<const GlobalVariable * > & Visiting)117 VisitGlobalVariableForEmission(const GlobalVariable *GV,
118 SmallVectorImpl<const GlobalVariable *> &Order,
119 DenseSet<const GlobalVariable *> &Visited,
120 DenseSet<const GlobalVariable *> &Visiting) {
121 // Have we already visited this one?
122 if (Visited.count(GV))
123 return;
124
125 // Do we have a circular dependency?
126 if (!Visiting.insert(GV).second)
127 report_fatal_error("Circular dependency found in global variable set");
128
129 // Make sure we visit all dependents first
130 DenseSet<const GlobalVariable *> Others;
131 for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
132 DiscoverDependentGlobals(GV->getOperand(i), Others);
133
134 for (const GlobalVariable *GV : Others)
135 VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
136
137 // Now we can visit ourself
138 Order.push_back(GV);
139 Visited.insert(GV);
140 Visiting.erase(GV);
141 }
142
emitInstruction(const MachineInstr * MI)143 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
144 NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
145 getSubtargetInfo().getFeatureBits());
146
147 MCInst Inst;
148 lowerToMCInst(MI, Inst);
149 EmitToStreamer(*OutStreamer, Inst);
150 }
151
152 // Handle symbol backtracking for targets that do not support image handles
lowerImageHandleOperand(const MachineInstr * MI,unsigned OpNo,MCOperand & MCOp)153 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
154 unsigned OpNo, MCOperand &MCOp) {
155 const MachineOperand &MO = MI->getOperand(OpNo);
156 const MCInstrDesc &MCID = MI->getDesc();
157
158 if (MCID.TSFlags & NVPTXII::IsTexFlag) {
159 // This is a texture fetch, so operand 4 is a texref and operand 5 is
160 // a samplerref
161 if (OpNo == 4 && MO.isImm()) {
162 lowerImageHandleSymbol(MO.getImm(), MCOp);
163 return true;
164 }
165 if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
166 lowerImageHandleSymbol(MO.getImm(), MCOp);
167 return true;
168 }
169
170 return false;
171 } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
172 unsigned VecSize =
173 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
174
175 // For a surface load of vector size N, the Nth operand will be the surfref
176 if (OpNo == VecSize && MO.isImm()) {
177 lowerImageHandleSymbol(MO.getImm(), MCOp);
178 return true;
179 }
180
181 return false;
182 } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
183 // This is a surface store, so operand 0 is a surfref
184 if (OpNo == 0 && MO.isImm()) {
185 lowerImageHandleSymbol(MO.getImm(), MCOp);
186 return true;
187 }
188
189 return false;
190 } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
191 // This is a query, so operand 1 is a surfref/texref
192 if (OpNo == 1 && MO.isImm()) {
193 lowerImageHandleSymbol(MO.getImm(), MCOp);
194 return true;
195 }
196
197 return false;
198 }
199
200 return false;
201 }
202
lowerImageHandleSymbol(unsigned Index,MCOperand & MCOp)203 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
204 // Ewwww
205 LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
206 NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
207 const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
208 const char *Sym = MFI->getImageHandleSymbol(Index);
209 std::string *SymNamePtr =
210 nvTM.getManagedStrPool()->getManagedString(Sym);
211 MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(StringRef(*SymNamePtr)));
212 }
213
lowerToMCInst(const MachineInstr * MI,MCInst & OutMI)214 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
215 OutMI.setOpcode(MI->getOpcode());
216 // Special: Do not mangle symbol operand of CALL_PROTOTYPE
217 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
218 const MachineOperand &MO = MI->getOperand(0);
219 OutMI.addOperand(GetSymbolRef(
220 OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
221 return;
222 }
223
224 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
225 for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
226 const MachineOperand &MO = MI->getOperand(i);
227
228 MCOperand MCOp;
229 if (!STI.hasImageHandles()) {
230 if (lowerImageHandleOperand(MI, i, MCOp)) {
231 OutMI.addOperand(MCOp);
232 continue;
233 }
234 }
235
236 if (lowerOperand(MO, MCOp))
237 OutMI.addOperand(MCOp);
238 }
239 }
240
lowerOperand(const MachineOperand & MO,MCOperand & MCOp)241 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
242 MCOperand &MCOp) {
243 switch (MO.getType()) {
244 default: llvm_unreachable("unknown operand type");
245 case MachineOperand::MO_Register:
246 MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
247 break;
248 case MachineOperand::MO_Immediate:
249 MCOp = MCOperand::createImm(MO.getImm());
250 break;
251 case MachineOperand::MO_MachineBasicBlock:
252 MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
253 MO.getMBB()->getSymbol(), OutContext));
254 break;
255 case MachineOperand::MO_ExternalSymbol:
256 MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
257 break;
258 case MachineOperand::MO_GlobalAddress:
259 MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
260 break;
261 case MachineOperand::MO_FPImmediate: {
262 const ConstantFP *Cnt = MO.getFPImm();
263 const APFloat &Val = Cnt->getValueAPF();
264
265 switch (Cnt->getType()->getTypeID()) {
266 default: report_fatal_error("Unsupported FP type"); break;
267 case Type::HalfTyID:
268 MCOp = MCOperand::createExpr(
269 NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
270 break;
271 case Type::FloatTyID:
272 MCOp = MCOperand::createExpr(
273 NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
274 break;
275 case Type::DoubleTyID:
276 MCOp = MCOperand::createExpr(
277 NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
278 break;
279 }
280 break;
281 }
282 }
283 return true;
284 }
285
encodeVirtualRegister(unsigned Reg)286 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
287 if (Register::isVirtualRegister(Reg)) {
288 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
289
290 DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
291 unsigned RegNum = RegMap[Reg];
292
293 // Encode the register class in the upper 4 bits
294 // Must be kept in sync with NVPTXInstPrinter::printRegName
295 unsigned Ret = 0;
296 if (RC == &NVPTX::Int1RegsRegClass) {
297 Ret = (1 << 28);
298 } else if (RC == &NVPTX::Int16RegsRegClass) {
299 Ret = (2 << 28);
300 } else if (RC == &NVPTX::Int32RegsRegClass) {
301 Ret = (3 << 28);
302 } else if (RC == &NVPTX::Int64RegsRegClass) {
303 Ret = (4 << 28);
304 } else if (RC == &NVPTX::Float32RegsRegClass) {
305 Ret = (5 << 28);
306 } else if (RC == &NVPTX::Float64RegsRegClass) {
307 Ret = (6 << 28);
308 } else if (RC == &NVPTX::Float16RegsRegClass) {
309 Ret = (7 << 28);
310 } else if (RC == &NVPTX::Float16x2RegsRegClass) {
311 Ret = (8 << 28);
312 } else {
313 report_fatal_error("Bad register class");
314 }
315
316 // Insert the vreg number
317 Ret |= (RegNum & 0x0FFFFFFF);
318 return Ret;
319 } else {
320 // Some special-use registers are actually physical registers.
321 // Encode this as the register class ID of 0 and the real register ID.
322 return Reg & 0x0FFFFFFF;
323 }
324 }
325
GetSymbolRef(const MCSymbol * Symbol)326 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
327 const MCExpr *Expr;
328 Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
329 OutContext);
330 return MCOperand::createExpr(Expr);
331 }
332
printReturnValStr(const Function * F,raw_ostream & O)333 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
334 const DataLayout &DL = getDataLayout();
335 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
336 const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
337
338 Type *Ty = F->getReturnType();
339
340 bool isABI = (STI.getSmVersion() >= 20);
341
342 if (Ty->getTypeID() == Type::VoidTyID)
343 return;
344
345 O << " (";
346
347 if (isABI) {
348 if (Ty->isFloatingPointTy() || (Ty->isIntegerTy() && !Ty->isIntegerTy(128))) {
349 unsigned size = 0;
350 if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
351 size = ITy->getBitWidth();
352 } else {
353 assert(Ty->isFloatingPointTy() && "Floating point type expected here");
354 size = Ty->getPrimitiveSizeInBits();
355 }
356 // PTX ABI requires all scalar return values to be at least 32
357 // bits in size. fp16 normally uses .b16 as its storage type in
358 // PTX, so its size must be adjusted here, too.
359 size = promoteScalarArgumentSize(size);
360
361 O << ".param .b" << size << " func_retval0";
362 } else if (isa<PointerType>(Ty)) {
363 O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
364 << " func_retval0";
365 } else if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
366 unsigned totalsz = DL.getTypeAllocSize(Ty);
367 unsigned retAlignment = 0;
368 if (!getAlign(*F, 0, retAlignment))
369 retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
370 O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
371 << "]";
372 } else
373 llvm_unreachable("Unknown return type");
374 } else {
375 SmallVector<EVT, 16> vtparts;
376 ComputeValueVTs(*TLI, DL, Ty, vtparts);
377 unsigned idx = 0;
378 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
379 unsigned elems = 1;
380 EVT elemtype = vtparts[i];
381 if (vtparts[i].isVector()) {
382 elems = vtparts[i].getVectorNumElements();
383 elemtype = vtparts[i].getVectorElementType();
384 }
385
386 for (unsigned j = 0, je = elems; j != je; ++j) {
387 unsigned sz = elemtype.getSizeInBits();
388 if (elemtype.isInteger())
389 sz = promoteScalarArgumentSize(sz);
390 O << ".reg .b" << sz << " func_retval" << idx;
391 if (j < je - 1)
392 O << ", ";
393 ++idx;
394 }
395 if (i < e - 1)
396 O << ", ";
397 }
398 }
399 O << ") ";
400 }
401
printReturnValStr(const MachineFunction & MF,raw_ostream & O)402 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
403 raw_ostream &O) {
404 const Function &F = MF.getFunction();
405 printReturnValStr(&F, O);
406 }
407
408 // Return true if MBB is the header of a loop marked with
409 // llvm.loop.unroll.disable.
410 // TODO: consider "#pragma unroll 1" which is equivalent to "#pragma nounroll".
isLoopHeaderOfNoUnroll(const MachineBasicBlock & MBB) const411 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
412 const MachineBasicBlock &MBB) const {
413 MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
414 // We insert .pragma "nounroll" only to the loop header.
415 if (!LI.isLoopHeader(&MBB))
416 return false;
417
418 // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
419 // we iterate through each back edge of the loop with header MBB, and check
420 // whether its metadata contains llvm.loop.unroll.disable.
421 for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
422 if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
423 // Edges from other loops to MBB are not back edges.
424 continue;
425 }
426 if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
427 if (MDNode *LoopID =
428 PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
429 if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
430 return true;
431 }
432 }
433 }
434 return false;
435 }
436
emitBasicBlockStart(const MachineBasicBlock & MBB)437 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
438 AsmPrinter::emitBasicBlockStart(MBB);
439 if (isLoopHeaderOfNoUnroll(MBB))
440 OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
441 }
442
emitFunctionEntryLabel()443 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
444 SmallString<128> Str;
445 raw_svector_ostream O(Str);
446
447 if (!GlobalsEmitted) {
448 emitGlobals(*MF->getFunction().getParent());
449 GlobalsEmitted = true;
450 }
451
452 // Set up
453 MRI = &MF->getRegInfo();
454 F = &MF->getFunction();
455 emitLinkageDirective(F, O);
456 if (isKernelFunction(*F))
457 O << ".entry ";
458 else {
459 O << ".func ";
460 printReturnValStr(*MF, O);
461 }
462
463 CurrentFnSym->print(O, MAI);
464
465 emitFunctionParamList(*MF, O);
466
467 if (isKernelFunction(*F))
468 emitKernelFunctionDirectives(*F, O);
469
470 OutStreamer->emitRawText(O.str());
471
472 VRegMapping.clear();
473 // Emit open brace for function body.
474 OutStreamer->emitRawText(StringRef("{\n"));
475 setAndEmitFunctionVirtualRegisters(*MF);
476 // Emit initial .loc debug directive for correct relocation symbol data.
477 if (MMI && MMI->hasDebugInfo())
478 emitInitialRawDwarfLocDirective(*MF);
479 }
480
runOnMachineFunction(MachineFunction & F)481 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
482 bool Result = AsmPrinter::runOnMachineFunction(F);
483 // Emit closing brace for the body of function F.
484 // The closing brace must be emitted here because we need to emit additional
485 // debug labels/data after the last basic block.
486 // We need to emit the closing brace here because we don't have function that
487 // finished emission of the function body.
488 OutStreamer->emitRawText(StringRef("}\n"));
489 return Result;
490 }
491
emitFunctionBodyStart()492 void NVPTXAsmPrinter::emitFunctionBodyStart() {
493 SmallString<128> Str;
494 raw_svector_ostream O(Str);
495 emitDemotedVars(&MF->getFunction(), O);
496 OutStreamer->emitRawText(O.str());
497 }
498
/// End-of-body hook: discard the per-function virtual register numbering.
/// The mapping is cleared and rebuilt for each function in
/// emitFunctionEntryLabel (presumably populated by
/// setAndEmitFunctionVirtualRegisters).
void NVPTXAsmPrinter::emitFunctionBodyEnd() {
  VRegMapping.clear();
}
502
getFunctionFrameSymbol() const503 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
504 SmallString<128> Str;
505 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
506 return OutContext.getOrCreateSymbol(Str);
507 }
508
emitImplicitDef(const MachineInstr * MI) const509 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
510 Register RegNo = MI->getOperand(0).getReg();
511 if (Register::isVirtualRegister(RegNo)) {
512 OutStreamer->AddComment(Twine("implicit-def: ") +
513 getVirtualRegisterName(RegNo));
514 } else {
515 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
516 OutStreamer->AddComment(Twine("implicit-def: ") +
517 STI.getRegisterInfo()->getName(RegNo));
518 }
519 OutStreamer->addBlankLine();
520 }
521
emitKernelFunctionDirectives(const Function & F,raw_ostream & O) const522 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
523 raw_ostream &O) const {
524 // If the NVVM IR has some of reqntid* specified, then output
525 // the reqntid directive, and set the unspecified ones to 1.
526 // If none of reqntid* is specified, don't output reqntid directive.
527 unsigned reqntidx, reqntidy, reqntidz;
528 bool specified = false;
529 if (!getReqNTIDx(F, reqntidx))
530 reqntidx = 1;
531 else
532 specified = true;
533 if (!getReqNTIDy(F, reqntidy))
534 reqntidy = 1;
535 else
536 specified = true;
537 if (!getReqNTIDz(F, reqntidz))
538 reqntidz = 1;
539 else
540 specified = true;
541
542 if (specified)
543 O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz
544 << "\n";
545
546 // If the NVVM IR has some of maxntid* specified, then output
547 // the maxntid directive, and set the unspecified ones to 1.
548 // If none of maxntid* is specified, don't output maxntid directive.
549 unsigned maxntidx, maxntidy, maxntidz;
550 specified = false;
551 if (!getMaxNTIDx(F, maxntidx))
552 maxntidx = 1;
553 else
554 specified = true;
555 if (!getMaxNTIDy(F, maxntidy))
556 maxntidy = 1;
557 else
558 specified = true;
559 if (!getMaxNTIDz(F, maxntidz))
560 maxntidz = 1;
561 else
562 specified = true;
563
564 if (specified)
565 O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz
566 << "\n";
567
568 unsigned mincta;
569 if (getMinCTASm(F, mincta))
570 O << ".minnctapersm " << mincta << "\n";
571
572 unsigned maxnreg;
573 if (getMaxNReg(F, maxnreg))
574 O << ".maxnreg " << maxnreg << "\n";
575 }
576
577 std::string
getVirtualRegisterName(unsigned Reg) const578 NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
579 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
580
581 std::string Name;
582 raw_string_ostream NameStr(Name);
583
584 VRegRCMap::const_iterator I = VRegMapping.find(RC);
585 assert(I != VRegMapping.end() && "Bad register class");
586 const DenseMap<unsigned, unsigned> &RegMap = I->second;
587
588 VRegMap::const_iterator VI = RegMap.find(Reg);
589 assert(VI != RegMap.end() && "Bad virtual register");
590 unsigned MappedVR = VI->second;
591
592 NameStr << getNVPTXRegClassStr(RC) << MappedVR;
593
594 NameStr.flush();
595 return Name;
596 }
597
/// Print the PTX name of virtual register \p vr (as built by
/// getVirtualRegisterName) to \p O.
void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
                                          raw_ostream &O) {
  O << getVirtualRegisterName(vr);
}
602
emitDeclaration(const Function * F,raw_ostream & O)603 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
604 emitLinkageDirective(F, O);
605 if (isKernelFunction(*F))
606 O << ".entry ";
607 else
608 O << ".func ";
609 printReturnValStr(F, O);
610 getSymbol(F)->print(O, MAI);
611 O << "\n";
612 emitFunctionParamList(F, O);
613 O << ";\n";
614 }
615
usedInGlobalVarDef(const Constant * C)616 static bool usedInGlobalVarDef(const Constant *C) {
617 if (!C)
618 return false;
619
620 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
621 return GV->getName() != "llvm.used";
622 }
623
624 for (const User *U : C->users())
625 if (const Constant *C = dyn_cast<Constant>(U))
626 if (usedInGlobalVarDef(C))
627 return true;
628
629 return false;
630 }
631
usedInOneFunc(const User * U,Function const * & oneFunc)632 static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
633 if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
634 if (othergv->getName() == "llvm.used")
635 return true;
636 }
637
638 if (const Instruction *instr = dyn_cast<Instruction>(U)) {
639 if (instr->getParent() && instr->getParent()->getParent()) {
640 const Function *curFunc = instr->getParent()->getParent();
641 if (oneFunc && (curFunc != oneFunc))
642 return false;
643 oneFunc = curFunc;
644 return true;
645 } else
646 return false;
647 }
648
649 for (const User *UU : U->users())
650 if (!usedInOneFunc(UU, oneFunc))
651 return false;
652
653 return true;
654 }
655
656 /* Find out if a global variable can be demoted to local scope.
657 * Currently, this is valid for CUDA shared variables, which have local
658 * scope and global lifetime. So the conditions to check are :
659 * 1. Is the global variable in shared address space?
660 * 2. Does it have internal linkage?
661 * 3. Is the global variable referenced only in one function?
662 */
canDemoteGlobalVar(const GlobalVariable * gv,Function const * & f)663 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
664 if (!gv->hasInternalLinkage())
665 return false;
666 PointerType *Pty = gv->getType();
667 if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
668 return false;
669
670 const Function *oneFunc = nullptr;
671
672 bool flag = usedInOneFunc(gv, oneFunc);
673 if (!flag)
674 return false;
675 if (!oneFunc)
676 return false;
677 f = oneFunc;
678 return true;
679 }
680
useFuncSeen(const Constant * C,DenseMap<const Function *,bool> & seenMap)681 static bool useFuncSeen(const Constant *C,
682 DenseMap<const Function *, bool> &seenMap) {
683 for (const User *U : C->users()) {
684 if (const Constant *cu = dyn_cast<Constant>(U)) {
685 if (useFuncSeen(cu, seenMap))
686 return true;
687 } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
688 const BasicBlock *bb = I->getParent();
689 if (!bb)
690 continue;
691 const Function *caller = bb->getParent();
692 if (!caller)
693 continue;
694 if (seenMap.find(caller) != seenMap.end())
695 return true;
696 }
697 }
698 return false;
699 }
700
emitDeclarations(const Module & M,raw_ostream & O)701 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
702 DenseMap<const Function *, bool> seenMap;
703 for (const Function &F : M) {
704 if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
705 emitDeclaration(&F, O);
706 continue;
707 }
708
709 if (F.isDeclaration()) {
710 if (F.use_empty())
711 continue;
712 if (F.getIntrinsicID())
713 continue;
714 emitDeclaration(&F, O);
715 continue;
716 }
717 for (const User *U : F.users()) {
718 if (const Constant *C = dyn_cast<Constant>(U)) {
719 if (usedInGlobalVarDef(C)) {
720 // The use is in the initialization of a global variable
721 // that is a function pointer, so print a declaration
722 // for the original function
723 emitDeclaration(&F, O);
724 break;
725 }
726 // Emit a declaration of this function if the function that
727 // uses this constant expr has already been seen.
728 if (useFuncSeen(C, seenMap)) {
729 emitDeclaration(&F, O);
730 break;
731 }
732 }
733
734 if (!isa<Instruction>(U))
735 continue;
736 const Instruction *instr = cast<Instruction>(U);
737 const BasicBlock *bb = instr->getParent();
738 if (!bb)
739 continue;
740 const Function *caller = bb->getParent();
741 if (!caller)
742 continue;
743
744 // If a caller has already been seen, then the caller is
745 // appearing in the module before the callee. so print out
746 // a declaration for the callee.
747 if (seenMap.find(caller) != seenMap.end()) {
748 emitDeclaration(&F, O);
749 break;
750 }
751 }
752 seenMap[&F] = true;
753 }
754 }
755
isEmptyXXStructor(GlobalVariable * GV)756 static bool isEmptyXXStructor(GlobalVariable *GV) {
757 if (!GV) return true;
758 const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
759 if (!InitList) return true; // Not an array; we don't know how to parse.
760 return InitList->getNumOperands() == 0;
761 }
762
emitStartOfAsmFile(Module & M)763 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
764 // Construct a default subtarget off of the TargetMachine defaults. The
765 // rest of NVPTX isn't friendly to change subtargets per function and
766 // so the default TargetMachine will have all of the options.
767 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
768 const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
769 SmallString<128> Str1;
770 raw_svector_ostream OS1(Str1);
771
772 // Emit header before any dwarf directives are emitted below.
773 emitHeader(M, OS1, *STI);
774 OutStreamer->emitRawText(OS1.str());
775 }
776
doInitialization(Module & M)777 bool NVPTXAsmPrinter::doInitialization(Module &M) {
778 if (M.alias_size()) {
779 report_fatal_error("Module has aliases, which NVPTX does not support.");
780 return true; // error
781 }
782 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors"))) {
783 report_fatal_error(
784 "Module has a nontrivial global ctor, which NVPTX does not support.");
785 return true; // error
786 }
787 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors"))) {
788 report_fatal_error(
789 "Module has a nontrivial global dtor, which NVPTX does not support.");
790 return true; // error
791 }
792
793 // We need to call the parent's one explicitly.
794 bool Result = AsmPrinter::doInitialization(M);
795
796 GlobalsEmitted = false;
797
798 return Result;
799 }
800
emitGlobals(const Module & M)801 void NVPTXAsmPrinter::emitGlobals(const Module &M) {
802 SmallString<128> Str2;
803 raw_svector_ostream OS2(Str2);
804
805 emitDeclarations(M, OS2);
806
807 // As ptxas does not support forward references of globals, we need to first
808 // sort the list of module-level globals in def-use order. We visit each
809 // global variable in order, and ensure that we emit it *after* its dependent
810 // globals. We use a little extra memory maintaining both a set and a list to
811 // have fast searches while maintaining a strict ordering.
812 SmallVector<const GlobalVariable *, 8> Globals;
813 DenseSet<const GlobalVariable *> GVVisited;
814 DenseSet<const GlobalVariable *> GVVisiting;
815
816 // Visit each global variable, in order
817 for (const GlobalVariable &I : M.globals())
818 VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
819
820 assert(GVVisited.size() == M.getGlobalList().size() &&
821 "Missed a global variable");
822 assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
823
824 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
825 const NVPTXSubtarget &STI =
826 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
827
828 // Print out module-level global variables in proper order
829 for (unsigned i = 0, e = Globals.size(); i != e; ++i)
830 printModuleLevelGV(Globals[i], OS2, /*processDemoted=*/false, STI);
831
832 OS2 << '\n';
833
834 OutStreamer->emitRawText(OS2.str());
835 }
836
emitHeader(Module & M,raw_ostream & O,const NVPTXSubtarget & STI)837 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
838 const NVPTXSubtarget &STI) {
839 O << "//\n";
840 O << "// Generated by LLVM NVPTX Back-End\n";
841 O << "//\n";
842 O << "\n";
843
844 unsigned PTXVersion = STI.getPTXVersion();
845 O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
846
847 O << ".target ";
848 O << STI.getTargetName();
849
850 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
851 if (NTM.getDrvInterface() == NVPTX::NVCL)
852 O << ", texmode_independent";
853
854 bool HasFullDebugInfo = false;
855 for (DICompileUnit *CU : M.debug_compile_units()) {
856 switch(CU->getEmissionKind()) {
857 case DICompileUnit::NoDebug:
858 case DICompileUnit::DebugDirectivesOnly:
859 break;
860 case DICompileUnit::LineTablesOnly:
861 case DICompileUnit::FullDebug:
862 HasFullDebugInfo = true;
863 break;
864 }
865 if (HasFullDebugInfo)
866 break;
867 }
868 if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
869 O << ", debug";
870
871 O << "\n";
872
873 O << ".address_size ";
874 if (NTM.is64Bit())
875 O << "64";
876 else
877 O << "32";
878 O << "\n";
879
880 O << "\n";
881 }
882
doFinalization(Module & M)883 bool NVPTXAsmPrinter::doFinalization(Module &M) {
884 bool HasDebugInfo = MMI && MMI->hasDebugInfo();
885
886 // If we did not emit any functions, then the global declarations have not
887 // yet been emitted.
888 if (!GlobalsEmitted) {
889 emitGlobals(M);
890 GlobalsEmitted = true;
891 }
892
893 // call doFinalization
894 bool ret = AsmPrinter::doFinalization(M);
895
896 clearAnnotationCache(&M);
897
898 if (auto *TS = static_cast<NVPTXTargetStreamer *>(
899 OutStreamer->getTargetStreamer())) {
900 // Close the last emitted section
901 if (HasDebugInfo) {
902 TS->closeLastSection();
903 // Emit empty .debug_loc section for better support of the empty files.
904 OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
905 }
906
907 // Output last DWARF .file directives, if any.
908 TS->outputDwarfFileDirectives();
909 }
910
911 return ret;
912
913 //bool Result = AsmPrinter::doFinalization(M);
914 // Instead of calling the parents doFinalization, we may
915 // clone parents doFinalization and customize here.
916 // Currently, we if NVISA out the EmitGlobals() in
917 // parent's doFinalization, which is too intrusive.
918 //
919 // Same for the doInitialization.
920 //return Result;
921 }
922
// This function emits the appropriate linkage directive for a function or
// global variable. Directives are only emitted for the CUDA driver interface:
//
// external function declaration -> .extern
// external function definition -> .visible
// external global variable with init -> .visible
// external global variable without init -> .extern
// appending -> not supported; treated as an error.
// any linkage other than internal or private -> .weak
935
emitLinkageDirective(const GlobalValue * V,raw_ostream & O)936 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
937 raw_ostream &O) {
938 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
939 if (V->hasExternalLinkage()) {
940 if (isa<GlobalVariable>(V)) {
941 const GlobalVariable *GVar = cast<GlobalVariable>(V);
942 if (GVar) {
943 if (GVar->hasInitializer())
944 O << ".visible ";
945 else
946 O << ".extern ";
947 }
948 } else if (V->isDeclaration())
949 O << ".extern ";
950 else
951 O << ".visible ";
952 } else if (V->hasAppendingLinkage()) {
953 std::string msg;
954 msg.append("Error: ");
955 msg.append("Symbol ");
956 if (V->hasName())
957 msg.append(std::string(V->getName()));
958 msg.append("has unsupported appending linkage type");
959 llvm_unreachable(msg.c_str());
960 } else if (!V->hasInternalLinkage() &&
961 !V->hasPrivateLinkage()) {
962 O << ".weak ";
963 }
964 }
965 }
966
printModuleLevelGV(const GlobalVariable * GVar,raw_ostream & O,bool processDemoted,const NVPTXSubtarget & STI)967 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
968 raw_ostream &O, bool processDemoted,
969 const NVPTXSubtarget &STI) {
970 // Skip meta data
971 if (GVar->hasSection()) {
972 if (GVar->getSection() == "llvm.metadata")
973 return;
974 }
975
976 // Skip LLVM intrinsic global variables
977 if (GVar->getName().startswith("llvm.") ||
978 GVar->getName().startswith("nvvm."))
979 return;
980
981 const DataLayout &DL = getDataLayout();
982
983 // GlobalVariables are always constant pointers themselves.
984 PointerType *PTy = GVar->getType();
985 Type *ETy = GVar->getValueType();
986
987 if (GVar->hasExternalLinkage()) {
988 if (GVar->hasInitializer())
989 O << ".visible ";
990 else
991 O << ".extern ";
992 } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
993 GVar->hasAvailableExternallyLinkage() ||
994 GVar->hasCommonLinkage()) {
995 O << ".weak ";
996 }
997
998 if (isTexture(*GVar)) {
999 O << ".global .texref " << getTextureName(*GVar) << ";\n";
1000 return;
1001 }
1002
1003 if (isSurface(*GVar)) {
1004 O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
1005 return;
1006 }
1007
1008 if (GVar->isDeclaration()) {
1009 // (extern) declarations, no definition or initializer
1010 // Currently the only known declaration is for an automatic __local
1011 // (.shared) promoted to global.
1012 emitPTXGlobalVariable(GVar, O, STI);
1013 O << ";\n";
1014 return;
1015 }
1016
1017 if (isSampler(*GVar)) {
1018 O << ".global .samplerref " << getSamplerName(*GVar);
1019
1020 const Constant *Initializer = nullptr;
1021 if (GVar->hasInitializer())
1022 Initializer = GVar->getInitializer();
1023 const ConstantInt *CI = nullptr;
1024 if (Initializer)
1025 CI = dyn_cast<ConstantInt>(Initializer);
1026 if (CI) {
1027 unsigned sample = CI->getZExtValue();
1028
1029 O << " = { ";
1030
1031 for (int i = 0,
1032 addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
1033 i < 3; i++) {
1034 O << "addr_mode_" << i << " = ";
1035 switch (addr) {
1036 case 0:
1037 O << "wrap";
1038 break;
1039 case 1:
1040 O << "clamp_to_border";
1041 break;
1042 case 2:
1043 O << "clamp_to_edge";
1044 break;
1045 case 3:
1046 O << "wrap";
1047 break;
1048 case 4:
1049 O << "mirror";
1050 break;
1051 }
1052 O << ", ";
1053 }
1054 O << "filter_mode = ";
1055 switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
1056 case 0:
1057 O << "nearest";
1058 break;
1059 case 1:
1060 O << "linear";
1061 break;
1062 case 2:
1063 llvm_unreachable("Anisotropic filtering is not supported");
1064 default:
1065 O << "nearest";
1066 break;
1067 }
1068 if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
1069 O << ", force_unnormalized_coords = 1";
1070 }
1071 O << " }";
1072 }
1073
1074 O << ";\n";
1075 return;
1076 }
1077
1078 if (GVar->hasPrivateLinkage()) {
1079 if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
1080 return;
1081
1082 // FIXME - need better way (e.g. Metadata) to avoid generating this global
1083 if (strncmp(GVar->getName().data(), "filename", 8) == 0)
1084 return;
1085 if (GVar->use_empty())
1086 return;
1087 }
1088
1089 const Function *demotedFunc = nullptr;
1090 if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
1091 O << "// " << GVar->getName() << " has been demoted\n";
1092 if (localDecls.find(demotedFunc) != localDecls.end())
1093 localDecls[demotedFunc].push_back(GVar);
1094 else {
1095 std::vector<const GlobalVariable *> temp;
1096 temp.push_back(GVar);
1097 localDecls[demotedFunc] = temp;
1098 }
1099 return;
1100 }
1101
1102 O << ".";
1103 emitPTXAddressSpace(PTy->getAddressSpace(), O);
1104
1105 if (isManaged(*GVar)) {
1106 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1107 report_fatal_error(
1108 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1109 }
1110 O << " .attribute(.managed)";
1111 }
1112
1113 if (MaybeAlign A = GVar->getAlign())
1114 O << " .align " << A->value();
1115 else
1116 O << " .align " << (int)DL.getPrefTypeAlignment(ETy);
1117
1118 if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
1119 (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
1120 O << " .";
1121 // Special case: ABI requires that we use .u8 for predicates
1122 if (ETy->isIntegerTy(1))
1123 O << "u8";
1124 else
1125 O << getPTXFundamentalTypeStr(ETy, false);
1126 O << " ";
1127 getSymbol(GVar)->print(O, MAI);
1128
1129 // Ptx allows variable initilization only for constant and global state
1130 // spaces.
1131 if (GVar->hasInitializer()) {
1132 if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1133 (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1134 const Constant *Initializer = GVar->getInitializer();
1135 // 'undef' is treated as there is no value specified.
1136 if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
1137 O << " = ";
1138 printScalarConstant(Initializer, O);
1139 }
1140 } else {
1141 // The frontend adds zero-initializer to device and constant variables
1142 // that don't have an initial value, and UndefValue to shared
1143 // variables, so skip warning for this case.
1144 if (!GVar->getInitializer()->isNullValue() &&
1145 !isa<UndefValue>(GVar->getInitializer())) {
1146 report_fatal_error("initial value of '" + GVar->getName() +
1147 "' is not allowed in addrspace(" +
1148 Twine(PTy->getAddressSpace()) + ")");
1149 }
1150 }
1151 }
1152 } else {
1153 unsigned int ElementSize = 0;
1154
1155 // Although PTX has direct support for struct type and array type and
1156 // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1157 // targets that support these high level field accesses. Structs, arrays
1158 // and vectors are lowered into arrays of bytes.
1159 switch (ETy->getTypeID()) {
1160 case Type::IntegerTyID: // Integers larger than 64 bits
1161 case Type::StructTyID:
1162 case Type::ArrayTyID:
1163 case Type::FixedVectorTyID:
1164 ElementSize = DL.getTypeStoreSize(ETy);
1165 // Ptx allows variable initilization only for constant and
1166 // global state spaces.
1167 if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1168 (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1169 GVar->hasInitializer()) {
1170 const Constant *Initializer = GVar->getInitializer();
1171 if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
1172 AggBuffer aggBuffer(ElementSize, *this);
1173 bufferAggregateConstant(Initializer, &aggBuffer);
1174 if (aggBuffer.numSymbols()) {
1175 unsigned int ptrSize = MAI->getCodePointerSize();
1176 if (ElementSize % ptrSize ||
1177 !aggBuffer.allSymbolsAligned(ptrSize)) {
1178 // Print in bytes and use the mask() operator for pointers.
1179 if (!STI.hasMaskOperator())
1180 report_fatal_error(
1181 "initialized packed aggregate with pointers '" +
1182 GVar->getName() +
1183 "' requires at least PTX ISA version 7.1");
1184 O << " .u8 ";
1185 getSymbol(GVar)->print(O, MAI);
1186 O << "[" << ElementSize << "] = {";
1187 aggBuffer.printBytes(O);
1188 O << "}";
1189 } else {
1190 O << " .u" << ptrSize * 8 << " ";
1191 getSymbol(GVar)->print(O, MAI);
1192 O << "[" << ElementSize / ptrSize << "] = {";
1193 aggBuffer.printWords(O);
1194 O << "}";
1195 }
1196 } else {
1197 O << " .b8 ";
1198 getSymbol(GVar)->print(O, MAI);
1199 O << "[" << ElementSize << "] = {";
1200 aggBuffer.printBytes(O);
1201 O << "}";
1202 }
1203 } else {
1204 O << " .b8 ";
1205 getSymbol(GVar)->print(O, MAI);
1206 if (ElementSize) {
1207 O << "[";
1208 O << ElementSize;
1209 O << "]";
1210 }
1211 }
1212 } else {
1213 O << " .b8 ";
1214 getSymbol(GVar)->print(O, MAI);
1215 if (ElementSize) {
1216 O << "[";
1217 O << ElementSize;
1218 O << "]";
1219 }
1220 }
1221 break;
1222 default:
1223 llvm_unreachable("type not supported yet");
1224 }
1225 }
1226 O << ";\n";
1227 }
1228
printSymbol(unsigned nSym,raw_ostream & os)1229 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1230 const Value *v = Symbols[nSym];
1231 const Value *v0 = SymbolsBeforeStripping[nSym];
1232 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1233 MCSymbol *Name = AP.getSymbol(GVar);
1234 PointerType *PTy = dyn_cast<PointerType>(v0->getType());
1235 // Is v0 a generic pointer?
1236 bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1237 if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
1238 os << "generic(";
1239 Name->print(os, AP.MAI);
1240 os << ")";
1241 } else {
1242 Name->print(os, AP.MAI);
1243 }
1244 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
1245 const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
1246 AP.printMCExpr(*Expr, os);
1247 } else
1248 llvm_unreachable("symbol type unknown");
1249 }
1250
printBytes(raw_ostream & os)1251 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
1252 unsigned int ptrSize = AP.MAI->getCodePointerSize();
1253 symbolPosInBuffer.push_back(size);
1254 unsigned int nSym = 0;
1255 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1256 for (unsigned int pos = 0; pos < size;) {
1257 if (pos)
1258 os << ", ";
1259 if (pos != nextSymbolPos) {
1260 os << (unsigned int)buffer[pos];
1261 ++pos;
1262 continue;
1263 }
1264 // Generate a per-byte mask() operator for the symbol, which looks like:
1265 // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1266 // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1267 std::string symText;
1268 llvm::raw_string_ostream oss(symText);
1269 printSymbol(nSym, oss);
1270 for (unsigned i = 0; i < ptrSize; ++i) {
1271 if (i)
1272 os << ", ";
1273 llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
1274 os << "(" << symText << ")";
1275 }
1276 pos += ptrSize;
1277 nextSymbolPos = symbolPosInBuffer[++nSym];
1278 assert(nextSymbolPos >= pos);
1279 }
1280 }
1281
printWords(raw_ostream & os)1282 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1283 unsigned int ptrSize = AP.MAI->getCodePointerSize();
1284 symbolPosInBuffer.push_back(size);
1285 unsigned int nSym = 0;
1286 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1287 assert(nextSymbolPos % ptrSize == 0);
1288 for (unsigned int pos = 0; pos < size; pos += ptrSize) {
1289 if (pos)
1290 os << ", ";
1291 if (pos == nextSymbolPos) {
1292 printSymbol(nSym, os);
1293 nextSymbolPos = symbolPosInBuffer[++nSym];
1294 assert(nextSymbolPos % ptrSize == 0);
1295 assert(nextSymbolPos >= pos + ptrSize);
1296 } else if (ptrSize == 4)
1297 os << support::endian::read32le(&buffer[pos]);
1298 else
1299 os << support::endian::read64le(&buffer[pos]);
1300 }
1301 }
1302
emitDemotedVars(const Function * f,raw_ostream & O)1303 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1304 if (localDecls.find(f) == localDecls.end())
1305 return;
1306
1307 std::vector<const GlobalVariable *> &gvars = localDecls[f];
1308
1309 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1310 const NVPTXSubtarget &STI =
1311 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1312
1313 for (const GlobalVariable *GV : gvars) {
1314 O << "\t// demoted variable\n\t";
1315 printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
1316 }
1317 }
1318
emitPTXAddressSpace(unsigned int AddressSpace,raw_ostream & O) const1319 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1320 raw_ostream &O) const {
1321 switch (AddressSpace) {
1322 case ADDRESS_SPACE_LOCAL:
1323 O << "local";
1324 break;
1325 case ADDRESS_SPACE_GLOBAL:
1326 O << "global";
1327 break;
1328 case ADDRESS_SPACE_CONST:
1329 O << "const";
1330 break;
1331 case ADDRESS_SPACE_SHARED:
1332 O << "shared";
1333 break;
1334 default:
1335 report_fatal_error("Bad address space found while emitting PTX: " +
1336 llvm::Twine(AddressSpace));
1337 break;
1338 }
1339 }
1340
1341 std::string
getPTXFundamentalTypeStr(Type * Ty,bool useB4PTR) const1342 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1343 switch (Ty->getTypeID()) {
1344 case Type::IntegerTyID: {
1345 unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1346 if (NumBits == 1)
1347 return "pred";
1348 else if (NumBits <= 64) {
1349 std::string name = "u";
1350 return name + utostr(NumBits);
1351 } else {
1352 llvm_unreachable("Integer too large");
1353 break;
1354 }
1355 break;
1356 }
1357 case Type::HalfTyID:
1358 // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
1359 return "b16";
1360 case Type::FloatTyID:
1361 return "f32";
1362 case Type::DoubleTyID:
1363 return "f64";
1364 case Type::PointerTyID:
1365 if (static_cast<const NVPTXTargetMachine &>(TM).is64Bit())
1366 if (useB4PTR)
1367 return "b64";
1368 else
1369 return "u64";
1370 else if (useB4PTR)
1371 return "b32";
1372 else
1373 return "u32";
1374 default:
1375 break;
1376 }
1377 llvm_unreachable("unexpected type");
1378 }
1379
emitPTXGlobalVariable(const GlobalVariable * GVar,raw_ostream & O,const NVPTXSubtarget & STI)1380 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1381 raw_ostream &O,
1382 const NVPTXSubtarget &STI) {
1383 const DataLayout &DL = getDataLayout();
1384
1385 // GlobalVariables are always constant pointers themselves.
1386 Type *ETy = GVar->getValueType();
1387
1388 O << ".";
1389 emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1390 if (isManaged(*GVar)) {
1391 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1392 report_fatal_error(
1393 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1394 }
1395 O << " .attribute(.managed)";
1396 }
1397 if (MaybeAlign A = GVar->getAlign())
1398 O << " .align " << A->value();
1399 else
1400 O << " .align " << (int)DL.getPrefTypeAlignment(ETy);
1401
1402 // Special case for i128
1403 if (ETy->isIntegerTy(128)) {
1404 O << " .b8 ";
1405 getSymbol(GVar)->print(O, MAI);
1406 O << "[16]";
1407 return;
1408 }
1409
1410 if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1411 O << " .";
1412 O << getPTXFundamentalTypeStr(ETy);
1413 O << " ";
1414 getSymbol(GVar)->print(O, MAI);
1415 return;
1416 }
1417
1418 int64_t ElementSize = 0;
1419
1420 // Although PTX has direct support for struct type and array type and LLVM IR
1421 // is very similar to PTX, the LLVM CodeGen does not support for targets that
1422 // support these high level field accesses. Structs and arrays are lowered
1423 // into arrays of bytes.
1424 switch (ETy->getTypeID()) {
1425 case Type::StructTyID:
1426 case Type::ArrayTyID:
1427 case Type::FixedVectorTyID:
1428 ElementSize = DL.getTypeStoreSize(ETy);
1429 O << " .b8 ";
1430 getSymbol(GVar)->print(O, MAI);
1431 O << "[";
1432 if (ElementSize) {
1433 O << ElementSize;
1434 }
1435 O << "]";
1436 break;
1437 default:
1438 llvm_unreachable("type not supported yet");
1439 }
1440 }
1441
printParamName(Function::const_arg_iterator I,int paramIndex,raw_ostream & O)1442 void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
1443 int paramIndex, raw_ostream &O) {
1444 getSymbol(I->getParent())->print(O, MAI);
1445 O << "_param_" << paramIndex;
1446 }
1447
emitFunctionParamList(const Function * F,raw_ostream & O)1448 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1449 const DataLayout &DL = getDataLayout();
1450 const AttributeList &PAL = F->getAttributes();
1451 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
1452 const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
1453
1454 Function::const_arg_iterator I, E;
1455 unsigned paramIndex = 0;
1456 bool first = true;
1457 bool isKernelFunc = isKernelFunction(*F);
1458 bool isABI = (STI.getSmVersion() >= 20);
1459 bool hasImageHandles = STI.hasImageHandles();
1460 MVT thePointerTy = TLI->getPointerTy(DL);
1461
1462 if (F->arg_empty()) {
1463 O << "()\n";
1464 return;
1465 }
1466
1467 O << "(\n";
1468
1469 for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
1470 Type *Ty = I->getType();
1471
1472 if (!first)
1473 O << ",\n";
1474
1475 first = false;
1476
1477 // Handle image/sampler parameters
1478 if (isKernelFunction(*F)) {
1479 if (isSampler(*I) || isImage(*I)) {
1480 if (isImage(*I)) {
1481 std::string sname = std::string(I->getName());
1482 if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
1483 if (hasImageHandles)
1484 O << "\t.param .u64 .ptr .surfref ";
1485 else
1486 O << "\t.param .surfref ";
1487 CurrentFnSym->print(O, MAI);
1488 O << "_param_" << paramIndex;
1489 }
1490 else { // Default image is read_only
1491 if (hasImageHandles)
1492 O << "\t.param .u64 .ptr .texref ";
1493 else
1494 O << "\t.param .texref ";
1495 CurrentFnSym->print(O, MAI);
1496 O << "_param_" << paramIndex;
1497 }
1498 } else {
1499 if (hasImageHandles)
1500 O << "\t.param .u64 .ptr .samplerref ";
1501 else
1502 O << "\t.param .samplerref ";
1503 CurrentFnSym->print(O, MAI);
1504 O << "_param_" << paramIndex;
1505 }
1506 continue;
1507 }
1508 }
1509
1510 auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
1511 paramIndex](Type *Ty) -> Align {
1512 Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
1513 MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
1514 return std::max(TypeAlign, ParamAlign.valueOrOne());
1515 };
1516
1517 if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
1518 if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
1519 // Just print .param .align <a> .b8 .param[size];
1520 // <a> = optimal alignment for the element type; always multiple of
1521 // PAL.getParamAlignment
1522 // size = typeallocsize of element type
1523 Align OptimalAlign = getOptimalAlignForParam(Ty);
1524
1525 O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1526 printParamName(I, paramIndex, O);
1527 O << "[" << DL.getTypeAllocSize(Ty) << "]";
1528
1529 continue;
1530 }
1531 // Just a scalar
1532 auto *PTy = dyn_cast<PointerType>(Ty);
1533 if (isKernelFunc) {
1534 if (PTy) {
1535 // Special handling for pointer arguments to kernel
1536 O << "\t.param .u" << thePointerTy.getSizeInBits() << " ";
1537
1538 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
1539 NVPTX::CUDA) {
1540 int addrSpace = PTy->getAddressSpace();
1541 switch (addrSpace) {
1542 default:
1543 O << ".ptr ";
1544 break;
1545 case ADDRESS_SPACE_CONST:
1546 O << ".ptr .const ";
1547 break;
1548 case ADDRESS_SPACE_SHARED:
1549 O << ".ptr .shared ";
1550 break;
1551 case ADDRESS_SPACE_GLOBAL:
1552 O << ".ptr .global ";
1553 break;
1554 }
1555 Align ParamAlign = I->getParamAlign().valueOrOne();
1556 O << ".align " << ParamAlign.value() << " ";
1557 }
1558 printParamName(I, paramIndex, O);
1559 continue;
1560 }
1561
1562 // non-pointer scalar to kernel func
1563 O << "\t.param .";
1564 // Special case: predicate operands become .u8 types
1565 if (Ty->isIntegerTy(1))
1566 O << "u8";
1567 else
1568 O << getPTXFundamentalTypeStr(Ty);
1569 O << " ";
1570 printParamName(I, paramIndex, O);
1571 continue;
1572 }
1573 // Non-kernel function, just print .param .b<size> for ABI
1574 // and .reg .b<size> for non-ABI
1575 unsigned sz = 0;
1576 if (isa<IntegerType>(Ty)) {
1577 sz = cast<IntegerType>(Ty)->getBitWidth();
1578 sz = promoteScalarArgumentSize(sz);
1579 } else if (isa<PointerType>(Ty))
1580 sz = thePointerTy.getSizeInBits();
1581 else if (Ty->isHalfTy())
1582 // PTX ABI requires all scalar parameters to be at least 32
1583 // bits in size. fp16 normally uses .b16 as its storage type
1584 // in PTX, so its size must be adjusted here, too.
1585 sz = 32;
1586 else
1587 sz = Ty->getPrimitiveSizeInBits();
1588 if (isABI)
1589 O << "\t.param .b" << sz << " ";
1590 else
1591 O << "\t.reg .b" << sz << " ";
1592 printParamName(I, paramIndex, O);
1593 continue;
1594 }
1595
1596 // param has byVal attribute.
1597 Type *ETy = PAL.getParamByValType(paramIndex);
1598 assert(ETy && "Param should have byval type");
1599
1600 if (isABI || isKernelFunc) {
1601 // Just print .param .align <a> .b8 .param[size];
1602 // <a> = optimal alignment for the element type; always multiple of
1603 // PAL.getParamAlignment
1604 // size = typeallocsize of element type
1605 Align OptimalAlign = getOptimalAlignForParam(ETy);
1606
1607 // Work around a bug in ptxas. When PTX code takes address of
1608 // byval parameter with alignment < 4, ptxas generates code to
1609 // spill argument into memory. Alas on sm_50+ ptxas generates
1610 // SASS code that fails with misaligned access. To work around
1611 // the problem, make sure that we align byval parameters by at
1612 // least 4. Matching change must be made in LowerCall() where we
1613 // prepare parameters for the call.
1614 //
1615 // TODO: this will need to be undone when we get to support multi-TU
1616 // device-side compilation as it breaks ABI compatibility with nvcc.
1617 // Hopefully ptxas bug is fixed by then.
1618 if (!isKernelFunc && OptimalAlign < Align(4))
1619 OptimalAlign = Align(4);
1620 unsigned sz = DL.getTypeAllocSize(ETy);
1621 O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1622 printParamName(I, paramIndex, O);
1623 O << "[" << sz << "]";
1624 continue;
1625 } else {
1626 // Split the ETy into constituent parts and
1627 // print .param .b<size> <name> for each part.
1628 // Further, if a part is vector, print the above for
1629 // each vector element.
1630 SmallVector<EVT, 16> vtparts;
1631 ComputeValueVTs(*TLI, DL, ETy, vtparts);
1632 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
1633 unsigned elems = 1;
1634 EVT elemtype = vtparts[i];
1635 if (vtparts[i].isVector()) {
1636 elems = vtparts[i].getVectorNumElements();
1637 elemtype = vtparts[i].getVectorElementType();
1638 }
1639
1640 for (unsigned j = 0, je = elems; j != je; ++j) {
1641 unsigned sz = elemtype.getSizeInBits();
1642 if (elemtype.isInteger())
1643 sz = promoteScalarArgumentSize(sz);
1644 O << "\t.reg .b" << sz << " ";
1645 printParamName(I, paramIndex, O);
1646 if (j < je - 1)
1647 O << ",\n";
1648 ++paramIndex;
1649 }
1650 if (i < e - 1)
1651 O << ",\n";
1652 }
1653 --paramIndex;
1654 continue;
1655 }
1656 }
1657
1658 O << "\n)\n";
1659 }
1660
emitFunctionParamList(const MachineFunction & MF,raw_ostream & O)1661 void NVPTXAsmPrinter::emitFunctionParamList(const MachineFunction &MF,
1662 raw_ostream &O) {
1663 const Function &F = MF.getFunction();
1664 emitFunctionParamList(&F, O);
1665 }
1666
setAndEmitFunctionVirtualRegisters(const MachineFunction & MF)1667 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1668 const MachineFunction &MF) {
1669 SmallString<128> Str;
1670 raw_svector_ostream O(Str);
1671
1672 // Map the global virtual register number to a register class specific
1673 // virtual register number starting from 1 with that class.
1674 const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1675 //unsigned numRegClasses = TRI->getNumRegClasses();
1676
1677 // Emit the Fake Stack Object
1678 const MachineFrameInfo &MFI = MF.getFrameInfo();
1679 int NumBytes = (int) MFI.getStackSize();
1680 if (NumBytes) {
1681 O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1682 << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1683 if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1684 O << "\t.reg .b64 \t%SP;\n";
1685 O << "\t.reg .b64 \t%SPL;\n";
1686 } else {
1687 O << "\t.reg .b32 \t%SP;\n";
1688 O << "\t.reg .b32 \t%SPL;\n";
1689 }
1690 }
1691
1692 // Go through all virtual registers to establish the mapping between the
1693 // global virtual
1694 // register number and the per class virtual register number.
1695 // We use the per class virtual register number in the ptx output.
1696 unsigned int numVRs = MRI->getNumVirtRegs();
1697 for (unsigned i = 0; i < numVRs; i++) {
1698 Register vr = Register::index2VirtReg(i);
1699 const TargetRegisterClass *RC = MRI->getRegClass(vr);
1700 DenseMap<unsigned, unsigned> ®map = VRegMapping[RC];
1701 int n = regmap.size();
1702 regmap.insert(std::make_pair(vr, n + 1));
1703 }
1704
1705 // Emit register declarations
1706 // @TODO: Extract out the real register usage
1707 // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1708 // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1709 // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1710 // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1711 // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1712 // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1713 // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1714
1715 // Emit declaration of the virtual registers or 'physical' registers for
1716 // each register class
1717 for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1718 const TargetRegisterClass *RC = TRI->getRegClass(i);
1719 DenseMap<unsigned, unsigned> ®map = VRegMapping[RC];
1720 std::string rcname = getNVPTXRegClassName(RC);
1721 std::string rcStr = getNVPTXRegClassStr(RC);
1722 int n = regmap.size();
1723
1724 // Only declare those registers that may be used.
1725 if (n) {
1726 O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1727 << ">;\n";
1728 }
1729 }
1730
1731 OutStreamer->emitRawText(O.str());
1732 }
1733
printFPConstant(const ConstantFP * Fp,raw_ostream & O)1734 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1735 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1736 bool ignored;
1737 unsigned int numHex;
1738 const char *lead;
1739
1740 if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1741 numHex = 8;
1742 lead = "0f";
1743 APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1744 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1745 numHex = 16;
1746 lead = "0d";
1747 APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1748 } else
1749 llvm_unreachable("unsupported fp type");
1750
1751 APInt API = APF.bitcastToAPInt();
1752 O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1753 }
1754
printScalarConstant(const Constant * CPV,raw_ostream & O)1755 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1756 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1757 O << CI->getValue();
1758 return;
1759 }
1760 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1761 printFPConstant(CFP, O);
1762 return;
1763 }
1764 if (isa<ConstantPointerNull>(CPV)) {
1765 O << "0";
1766 return;
1767 }
1768 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1769 bool IsNonGenericPointer = false;
1770 if (GVar->getType()->getAddressSpace() != 0) {
1771 IsNonGenericPointer = true;
1772 }
1773 if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1774 O << "generic(";
1775 getSymbol(GVar)->print(O, MAI);
1776 O << ")";
1777 } else {
1778 getSymbol(GVar)->print(O, MAI);
1779 }
1780 return;
1781 }
1782 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1783 const Value *v = Cexpr->stripPointerCasts();
1784 PointerType *PTy = dyn_cast<PointerType>(Cexpr->getType());
1785 bool IsNonGenericPointer = false;
1786 if (PTy && PTy->getAddressSpace() != 0) {
1787 IsNonGenericPointer = true;
1788 }
1789 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1790 if (EmitGeneric && !isa<Function>(v) && !IsNonGenericPointer) {
1791 O << "generic(";
1792 getSymbol(GVar)->print(O, MAI);
1793 O << ")";
1794 } else {
1795 getSymbol(GVar)->print(O, MAI);
1796 }
1797 return;
1798 } else {
1799 lowerConstant(CPV)->print(O, MAI);
1800 return;
1801 }
1802 }
1803 llvm_unreachable("Not scalar type found in printScalarConstant()");
1804 }
1805
bufferLEByte(const Constant * CPV,int Bytes,AggBuffer * AggBuffer)1806 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1807 AggBuffer *AggBuffer) {
1808 const DataLayout &DL = getDataLayout();
1809 int AllocSize = DL.getTypeAllocSize(CPV->getType());
1810 if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1811 // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1812 // only the space allocated by CPV.
1813 AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
1814 return;
1815 }
1816
1817 // Helper for filling AggBuffer with APInts.
1818 auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1819 size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1820 SmallVector<unsigned char, 16> Buf(NumBytes);
1821 for (unsigned I = 0; I < NumBytes; ++I) {
1822 Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
1823 }
1824 AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
1825 };
1826
1827 switch (CPV->getType()->getTypeID()) {
1828 case Type::IntegerTyID:
1829 if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
1830 AddIntToBuffer(CI->getValue());
1831 break;
1832 }
1833 if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1834 if (const auto *CI =
1835 dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
1836 AddIntToBuffer(CI->getValue());
1837 break;
1838 }
1839 if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1840 Value *V = Cexpr->getOperand(0)->stripPointerCasts();
1841 AggBuffer->addSymbol(V, Cexpr->getOperand(0));
1842 AggBuffer->addZeros(AllocSize);
1843 break;
1844 }
1845 }
1846 llvm_unreachable("unsupported integer const type");
1847 break;
1848
1849 case Type::HalfTyID:
1850 case Type::FloatTyID:
1851 case Type::DoubleTyID:
1852 AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
1853 break;
1854
1855 case Type::PointerTyID: {
1856 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1857 AggBuffer->addSymbol(GVar, GVar);
1858 } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1859 const Value *v = Cexpr->stripPointerCasts();
1860 AggBuffer->addSymbol(v, Cexpr);
1861 }
1862 AggBuffer->addZeros(AllocSize);
1863 break;
1864 }
1865
1866 case Type::ArrayTyID:
1867 case Type::FixedVectorTyID:
1868 case Type::StructTyID: {
1869 if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1870 bufferAggregateConstant(CPV, AggBuffer);
1871 if (Bytes > AllocSize)
1872 AggBuffer->addZeros(Bytes - AllocSize);
1873 } else if (isa<ConstantAggregateZero>(CPV))
1874 AggBuffer->addZeros(Bytes);
1875 else
1876 llvm_unreachable("Unexpected Constant type");
1877 break;
1878 }
1879
1880 default:
1881 llvm_unreachable("unsupported type");
1882 }
1883 }
1884
bufferAggregateConstant(const Constant * CPV,AggBuffer * aggBuffer)1885 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1886 AggBuffer *aggBuffer) {
1887 const DataLayout &DL = getDataLayout();
1888 int Bytes;
1889
1890 // Integers of arbitrary width
1891 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1892 APInt Val = CI->getValue();
1893 for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1894 uint8_t Byte = Val.getLoBits(8).getZExtValue();
1895 aggBuffer->addBytes(&Byte, 1, 1);
1896 Val.lshrInPlace(8);
1897 }
1898 return;
1899 }
1900
1901 // Old constants
1902 if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1903 if (CPV->getNumOperands())
1904 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1905 bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1906 return;
1907 }
1908
1909 if (const ConstantDataSequential *CDS =
1910 dyn_cast<ConstantDataSequential>(CPV)) {
1911 if (CDS->getNumElements())
1912 for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1913 bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1914 aggBuffer);
1915 return;
1916 }
1917
1918 if (isa<ConstantStruct>(CPV)) {
1919 if (CPV->getNumOperands()) {
1920 StructType *ST = cast<StructType>(CPV->getType());
1921 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1922 if (i == (e - 1))
1923 Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1924 DL.getTypeAllocSize(ST) -
1925 DL.getStructLayout(ST)->getElementOffset(i);
1926 else
1927 Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1928 DL.getStructLayout(ST)->getElementOffset(i);
1929 bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1930 }
1931 }
1932 return;
1933 }
1934 llvm_unreachable("unsupported constant type in printAggregateConstant()");
1935 }
1936
1937 /// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
1938 /// a copy from AsmPrinter::lowerConstant, except customized to only handle
1939 /// expressions that are representable in PTX and create
1940 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1941 const MCExpr *
lowerConstantForGV(const Constant * CV,bool ProcessingGeneric)1942 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
1943 MCContext &Ctx = OutContext;
1944
1945 if (CV->isNullValue() || isa<UndefValue>(CV))
1946 return MCConstantExpr::create(0, Ctx);
1947
1948 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
1949 return MCConstantExpr::create(CI->getZExtValue(), Ctx);
1950
1951 if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
1952 const MCSymbolRefExpr *Expr =
1953 MCSymbolRefExpr::create(getSymbol(GV), Ctx);
1954 if (ProcessingGeneric) {
1955 return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
1956 } else {
1957 return Expr;
1958 }
1959 }
1960
1961 const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
1962 if (!CE) {
1963 llvm_unreachable("Unknown constant value to lower!");
1964 }
1965
1966 switch (CE->getOpcode()) {
1967 default: {
1968 // If the code isn't optimized, there may be outstanding folding
1969 // opportunities. Attempt to fold the expression using DataLayout as a
1970 // last resort before giving up.
1971 Constant *C = ConstantFoldConstant(CE, getDataLayout());
1972 if (C != CE)
1973 return lowerConstantForGV(C, ProcessingGeneric);
1974
1975 // Otherwise report the problem to the user.
1976 std::string S;
1977 raw_string_ostream OS(S);
1978 OS << "Unsupported expression in static initializer: ";
1979 CE->printAsOperand(OS, /*PrintType=*/false,
1980 !MF ? nullptr : MF->getFunction().getParent());
1981 report_fatal_error(Twine(OS.str()));
1982 }
1983
1984 case Instruction::AddrSpaceCast: {
1985 // Strip the addrspacecast and pass along the operand
1986 PointerType *DstTy = cast<PointerType>(CE->getType());
1987 if (DstTy->getAddressSpace() == 0) {
1988 return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
1989 }
1990 std::string S;
1991 raw_string_ostream OS(S);
1992 OS << "Unsupported expression in static initializer: ";
1993 CE->printAsOperand(OS, /*PrintType=*/ false,
1994 !MF ? nullptr : MF->getFunction().getParent());
1995 report_fatal_error(Twine(OS.str()));
1996 }
1997
1998 case Instruction::GetElementPtr: {
1999 const DataLayout &DL = getDataLayout();
2000
2001 // Generate a symbolic expression for the byte address
2002 APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
2003 cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
2004
2005 const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
2006 ProcessingGeneric);
2007 if (!OffsetAI)
2008 return Base;
2009
2010 int64_t Offset = OffsetAI.getSExtValue();
2011 return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
2012 Ctx);
2013 }
2014
2015 case Instruction::Trunc:
2016 // We emit the value and depend on the assembler to truncate the generated
2017 // expression properly. This is important for differences between
2018 // blockaddress labels. Since the two labels are in the same function, it
2019 // is reasonable to treat their delta as a 32-bit value.
2020 LLVM_FALLTHROUGH;
2021 case Instruction::BitCast:
2022 return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2023
2024 case Instruction::IntToPtr: {
2025 const DataLayout &DL = getDataLayout();
2026
2027 // Handle casts to pointers by changing them into casts to the appropriate
2028 // integer type. This promotes constant folding and simplifies this code.
2029 Constant *Op = CE->getOperand(0);
2030 Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()),
2031 false/*ZExt*/);
2032 return lowerConstantForGV(Op, ProcessingGeneric);
2033 }
2034
2035 case Instruction::PtrToInt: {
2036 const DataLayout &DL = getDataLayout();
2037
2038 // Support only foldable casts to/from pointers that can be eliminated by
2039 // changing the pointer to the appropriately sized integer type.
2040 Constant *Op = CE->getOperand(0);
2041 Type *Ty = CE->getType();
2042
2043 const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
2044
2045 // We can emit the pointer value into this slot if the slot is an
2046 // integer slot equal to the size of the pointer.
2047 if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
2048 return OpExpr;
2049
2050 // Otherwise the pointer is smaller than the resultant integer, mask off
2051 // the high bits so we are sure to get a proper truncation if the input is
2052 // a constant expr.
2053 unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
2054 const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
2055 return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
2056 }
2057
2058 // The MC library also has a right-shift operator, but it isn't consistently
2059 // signed or unsigned between different targets.
2060 case Instruction::Add: {
2061 const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2062 const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
2063 switch (CE->getOpcode()) {
2064 default: llvm_unreachable("Unknown binary operator constant cast expr");
2065 case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
2066 }
2067 }
2068 }
2069 }
2070
2071 // Copy of MCExpr::print customized for NVPTX
printMCExpr(const MCExpr & Expr,raw_ostream & OS)2072 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2073 switch (Expr.getKind()) {
2074 case MCExpr::Target:
2075 return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
2076 case MCExpr::Constant:
2077 OS << cast<MCConstantExpr>(Expr).getValue();
2078 return;
2079
2080 case MCExpr::SymbolRef: {
2081 const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2082 const MCSymbol &Sym = SRE.getSymbol();
2083 Sym.print(OS, MAI);
2084 return;
2085 }
2086
2087 case MCExpr::Unary: {
2088 const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2089 switch (UE.getOpcode()) {
2090 case MCUnaryExpr::LNot: OS << '!'; break;
2091 case MCUnaryExpr::Minus: OS << '-'; break;
2092 case MCUnaryExpr::Not: OS << '~'; break;
2093 case MCUnaryExpr::Plus: OS << '+'; break;
2094 }
2095 printMCExpr(*UE.getSubExpr(), OS);
2096 return;
2097 }
2098
2099 case MCExpr::Binary: {
2100 const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2101
2102 // Only print parens around the LHS if it is non-trivial.
2103 if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2104 isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2105 printMCExpr(*BE.getLHS(), OS);
2106 } else {
2107 OS << '(';
2108 printMCExpr(*BE.getLHS(), OS);
2109 OS<< ')';
2110 }
2111
2112 switch (BE.getOpcode()) {
2113 case MCBinaryExpr::Add:
2114 // Print "X-42" instead of "X+-42".
2115 if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2116 if (RHSC->getValue() < 0) {
2117 OS << RHSC->getValue();
2118 return;
2119 }
2120 }
2121
2122 OS << '+';
2123 break;
2124 default: llvm_unreachable("Unhandled binary operator");
2125 }
2126
2127 // Only print parens around the LHS if it is non-trivial.
2128 if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2129 printMCExpr(*BE.getRHS(), OS);
2130 } else {
2131 OS << '(';
2132 printMCExpr(*BE.getRHS(), OS);
2133 OS << ')';
2134 }
2135 return;
2136 }
2137 }
2138
2139 llvm_unreachable("Invalid expression kind!");
2140 }
2141
2142 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2143 ///
PrintAsmOperand(const MachineInstr * MI,unsigned OpNo,const char * ExtraCode,raw_ostream & O)2144 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2145 const char *ExtraCode, raw_ostream &O) {
2146 if (ExtraCode && ExtraCode[0]) {
2147 if (ExtraCode[1] != 0)
2148 return true; // Unknown modifier.
2149
2150 switch (ExtraCode[0]) {
2151 default:
2152 // See if this is a generic print operand
2153 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2154 case 'r':
2155 break;
2156 }
2157 }
2158
2159 printOperand(MI, OpNo, O);
2160
2161 return false;
2162 }
2163
PrintAsmMemoryOperand(const MachineInstr * MI,unsigned OpNo,const char * ExtraCode,raw_ostream & O)2164 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2165 unsigned OpNo,
2166 const char *ExtraCode,
2167 raw_ostream &O) {
2168 if (ExtraCode && ExtraCode[0])
2169 return true; // Unknown modifier
2170
2171 O << '[';
2172 printMemOperand(MI, OpNo, O);
2173 O << ']';
2174
2175 return false;
2176 }
2177
printOperand(const MachineInstr * MI,int opNum,raw_ostream & O)2178 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
2179 raw_ostream &O) {
2180 const MachineOperand &MO = MI->getOperand(opNum);
2181 switch (MO.getType()) {
2182 case MachineOperand::MO_Register:
2183 if (Register::isPhysicalRegister(MO.getReg())) {
2184 if (MO.getReg() == NVPTX::VRDepot)
2185 O << DEPOTNAME << getFunctionNumber();
2186 else
2187 O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2188 } else {
2189 emitVirtualRegister(MO.getReg(), O);
2190 }
2191 break;
2192
2193 case MachineOperand::MO_Immediate:
2194 O << MO.getImm();
2195 break;
2196
2197 case MachineOperand::MO_FPImmediate:
2198 printFPConstant(MO.getFPImm(), O);
2199 break;
2200
2201 case MachineOperand::MO_GlobalAddress:
2202 PrintSymbolOperand(MO, O);
2203 break;
2204
2205 case MachineOperand::MO_MachineBasicBlock:
2206 MO.getMBB()->getSymbol()->print(O, MAI);
2207 break;
2208
2209 default:
2210 llvm_unreachable("Operand type not supported.");
2211 }
2212 }
2213
printMemOperand(const MachineInstr * MI,int opNum,raw_ostream & O,const char * Modifier)2214 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
2215 raw_ostream &O, const char *Modifier) {
2216 printOperand(MI, opNum, O);
2217
2218 if (Modifier && strcmp(Modifier, "add") == 0) {
2219 O << ", ";
2220 printOperand(MI, opNum + 1, O);
2221 } else {
2222 if (MI->getOperand(opNum + 1).isImm() &&
2223 MI->getOperand(opNum + 1).getImm() == 0)
2224 return; // don't print ',0' or '+0'
2225 O << "+";
2226 printOperand(MI, opNum + 1, O);
2227 }
2228 }
2229
2230 // Force static initialization.
LLVMInitializeNVPTXAsmPrinter()2231 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2232 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2233 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2234 }
2235