1 //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lowering of LLVM calls to machine code calls for
10 // GlobalISel.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "SPIRVCallLowering.h"
15 #include "MCTargetDesc/SPIRVBaseInfo.h"
16 #include "SPIRV.h"
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRVISelLowering.h"
19 #include "SPIRVRegisterInfo.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVUtils.h"
22 #include "llvm/CodeGen/FunctionLoweringInfo.h"
23
24 using namespace llvm;
25
SPIRVCallLowering(const SPIRVTargetLowering & TLI,SPIRVGlobalRegistry * GR)26 SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
27 SPIRVGlobalRegistry *GR)
28 : CallLowering(&TLI), GR(GR) {}
29
lowerReturn(MachineIRBuilder & MIRBuilder,const Value * Val,ArrayRef<Register> VRegs,FunctionLoweringInfo & FLI,Register SwiftErrorVReg) const30 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
31 const Value *Val, ArrayRef<Register> VRegs,
32 FunctionLoweringInfo &FLI,
33 Register SwiftErrorVReg) const {
34 // Currently all return types should use a single register.
35 // TODO: handle the case of multiple registers.
36 if (VRegs.size() > 1)
37 return false;
38 if (Val) {
39 const auto &STI = MIRBuilder.getMF().getSubtarget();
40 return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
41 .addUse(VRegs[0])
42 .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
43 *STI.getRegBankInfo());
44 }
45 MIRBuilder.buildInstr(SPIRV::OpReturn);
46 return true;
47 }
48
49 // Based on the LLVM function attributes, get a SPIR-V FunctionControl.
getFunctionControl(const Function & F)50 static uint32_t getFunctionControl(const Function &F) {
51 uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
52 if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) {
53 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
54 }
55 if (F.hasFnAttribute(Attribute::AttrKind::ReadNone)) {
56 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
57 }
58 if (F.hasFnAttribute(Attribute::AttrKind::ReadOnly)) {
59 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
60 }
61 if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) {
62 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
63 }
64 return FuncControl;
65 }
66
getConstInt(MDNode * MD,unsigned NumOp)67 static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
68 if (MD->getNumOperands() > NumOp) {
69 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp));
70 if (CMeta)
71 return dyn_cast<ConstantInt>(CMeta->getValue());
72 }
73 return nullptr;
74 }
75
76 // This code restores function args/retvalue types for composite cases
77 // because the final types should still be aggregate whereas they're i32
78 // during the translation to cope with aggregate flattening etc.
getOriginalFunctionType(const Function & F)79 static FunctionType *getOriginalFunctionType(const Function &F) {
80 auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
81 if (NamedMD == nullptr)
82 return F.getFunctionType();
83
84 Type *RetTy = F.getFunctionType()->getReturnType();
85 SmallVector<Type *, 4> ArgTypes;
86 for (auto &Arg : F.args())
87 ArgTypes.push_back(Arg.getType());
88
89 auto ThisFuncMDIt =
90 std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
91 return isa<MDString>(N->getOperand(0)) &&
92 cast<MDString>(N->getOperand(0))->getString() == F.getName();
93 });
94 // TODO: probably one function can have numerous type mutations,
95 // so we should support this.
96 if (ThisFuncMDIt != NamedMD->op_end()) {
97 auto *ThisFuncMD = *ThisFuncMDIt;
98 MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
99 assert(MD && "MDNode operand is expected");
100 ConstantInt *Const = getConstInt(MD, 0);
101 if (Const) {
102 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
103 assert(CMeta && "ConstantAsMetadata operand is expected");
104 assert(Const->getSExtValue() >= -1);
105 // Currently -1 indicates return value, greater values mean
106 // argument numbers.
107 if (Const->getSExtValue() == -1)
108 RetTy = CMeta->getType();
109 else
110 ArgTypes[Const->getSExtValue()] = CMeta->getType();
111 }
112 }
113
114 return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
115 }
116
// Lower the formal arguments of F and emit the SPIR-V function header.
//
// In order, this:
//   1. assigns a SPIR-V type to the vreg of every argument and emits the
//      argument's OpName and decorations derived from IR attributes and from
//      the "kernel_arg_access_qual", "kernel_arg_type_qual" and
//      "spirv.ParameterDecorations" function metadata;
//   2. emits OpFunction followed by one OpFunctionParameter per argument;
//   3. emits OpName for the function;
//   4. emits OpEntryPoint for SPIR_KERNEL functions, or a LinkageAttributes
//      decoration for functions with external / linkonce_odr linkage.
//
// VRegs holds the vregs of the arguments, indexed by argument position.
// Returns false (before the function header is emitted) if some argument
// uses more than one vreg; returns true otherwise.
//
// lowerCall also invokes this for callee declarations, which is why the
// isDeclaration() paths below record the created vregs in the registry.
bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
                                             const Function &F,
                                             ArrayRef<ArrayRef<Register>> VRegs,
                                             FunctionLoweringInfo &FLI) const {
  assert(GR && "Must initialize the SPIRV type registry before lowering args.");
  GR->setCurrentFunc(MIRBuilder.getMF());

  // Assign types and names to all args, and store their types for later.
  // The types come from the restored (pre-flattening) function type rather
  // than from the arguments themselves.
  FunctionType *FTy = getOriginalFunctionType(F);
  SmallVector<SPIRVType *, 4> ArgTypeVRegs;
  if (VRegs.size() > 0) {
    unsigned i = 0;
    for (const auto &Arg : F.args()) {
      // Currently formal args should use single registers.
      // TODO: handle the case of multiple registers.
      if (VRegs[i].size() > 1)
        return false;
      Type *ArgTy = FTy->getParamType(i);
      // The access qualifier defaults to ReadWrite; it is narrowed only when
      // the i-th "kernel_arg_access_qual" string is "read_only" or
      // "write_only".
      SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite;
      MDNode *Node = F.getMetadata("kernel_arg_access_qual");
      if (Node && i < Node->getNumOperands()) {
        StringRef AQString = cast<MDString>(Node->getOperand(i))->getString();
        if (AQString.compare("read_only") == 0)
          AQ = SPIRV::AccessQualifier::ReadOnly;
        else if (AQString.compare("write_only") == 0)
          AQ = SPIRV::AccessQualifier::WriteOnly;
      }
      auto *SpirvTy = GR->assignTypeToVReg(ArgTy, VRegs[i][0], MIRBuilder, AQ);
      ArgTypeVRegs.push_back(SpirvTy);

      if (Arg.hasName())
        buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
      // A non-zero dereferenceable byte count on a pointer argument becomes a
      // MaxByteOffset decoration.
      if (Arg.getType()->isPointerTy()) {
        auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
        if (DerefBytes != 0)
          buildOpDecorate(VRegs[i][0], MIRBuilder,
                          SPIRV::Decoration::MaxByteOffset, {DerefBytes});
      }
      // The align attribute becomes an Alignment decoration.
      if (Arg.hasAttribute(Attribute::Alignment)) {
        auto Alignment = static_cast<unsigned>(
            Arg.getAttribute(Attribute::Alignment).getValueAsInt());
        buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
                        {Alignment});
      }
      // readonly / zeroext / noalias each become a FuncParamAttr decoration
      // (NoWrite, Zext and NoAlias respectively).
      if (Arg.hasAttribute(Attribute::ReadOnly)) {
        auto Attr =
            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
        buildOpDecorate(VRegs[i][0], MIRBuilder,
                        SPIRV::Decoration::FuncParamAttr, {Attr});
      }
      if (Arg.hasAttribute(Attribute::ZExt)) {
        auto Attr =
            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
        buildOpDecorate(VRegs[i][0], MIRBuilder,
                        SPIRV::Decoration::FuncParamAttr, {Attr});
      }
      if (Arg.hasAttribute(Attribute::NoAlias)) {
        auto Attr =
            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
        buildOpDecorate(VRegs[i][0], MIRBuilder,
                        SPIRV::Decoration::FuncParamAttr, {Attr});
      }
      // A "volatile" entry in "kernel_arg_type_qual" becomes a Volatile
      // decoration; other qualifier strings are ignored here.
      Node = F.getMetadata("kernel_arg_type_qual");
      if (Node && i < Node->getNumOperands()) {
        StringRef TypeQual = cast<MDString>(Node->getOperand(i))->getString();
        if (TypeQual.compare("volatile") == 0)
          buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Volatile,
                          {});
      }
      // "spirv.ParameterDecorations": operand i is a list of decoration
      // nodes for this argument. In each node, operand 0 is the decoration
      // number and the remaining operands are its literal arguments.
      Node = F.getMetadata("spirv.ParameterDecorations");
      if (Node && i < Node->getNumOperands() &&
          isa<MDNode>(Node->getOperand(i))) {
        MDNode *MD = cast<MDNode>(Node->getOperand(i));
        for (const MDOperand &MDOp : MD->operands()) {
          MDNode *MD2 = dyn_cast<MDNode>(MDOp);
          assert(MD2 && "Metadata operand is expected");
          ConstantInt *Const = getConstInt(MD2, 0);
          assert(Const && "MDOperand should be ConstantInt");
          auto Dec = static_cast<SPIRV::Decoration>(Const->getZExtValue());
          std::vector<uint32_t> DecVec;
          for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
            ConstantInt *Const = getConstInt(MD2, j);
            assert(Const && "MDOperand should be ConstantInt");
            DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue()));
          }
          buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec);
        }
      }
      ++i;
    }
  }

  // Generate a SPIR-V type for the function.
  auto MRI = MIRBuilder.getMRI();
  Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
  MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
  // Record the vreg of a declaration in the registry; lowerCall checks the
  // registry (GR->find) before emitting a declaration for a callee.
  if (F.isDeclaration())
    GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
  SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
  SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
      FTy, RetTy, ArgTypeVRegs, MIRBuilder);

  // Emit the OpFunction instruction: result vreg, return type, function
  // control mask and function type.
  uint32_t FuncControl = getFunctionControl(F);

  MIRBuilder.buildInstr(SPIRV::OpFunction)
      .addDef(FuncVReg)
      .addUse(GR->getSPIRVTypeID(RetTy))
      .addImm(FuncControl)
      .addUse(GR->getSPIRVTypeID(FuncTy));

  // Add OpFunctionParameters.
  // NOTE(review): this loop indexes VRegs and ArgTypeVRegs for every argument
  // of F, so it relies on VRegs having one entry per argument.
  int i = 0;
  for (const auto &Arg : F.args()) {
    assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
    MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
    MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
        .addDef(VRegs[i][0])
        .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
    if (F.isDeclaration())
      GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
    i++;
  }
  // Name the function.
  if (F.hasName())
    buildOpName(FuncVReg, F.getName(), MIRBuilder);

  // Handle entry points and function linkage.
  // Kernels get an OpEntryPoint with the Kernel execution model; other
  // functions with external or linkonce_odr linkage get a LinkageAttributes
  // decoration (Import for declarations, Export for definitions).
  if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
    auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
                   .addImm(static_cast<uint32_t>(SPIRV::ExecutionModel::Kernel))
                   .addUse(FuncVReg);
    addStringImm(F.getName(), MIB);
  } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage ||
             F.getLinkage() == GlobalValue::LinkOnceODRLinkage) {
    auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import
                                   : SPIRV::LinkageType::Export;
    buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
                    {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
  }

  return true;
}
260
lowerCall(MachineIRBuilder & MIRBuilder,CallLoweringInfo & Info) const261 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
262 CallLoweringInfo &Info) const {
263 // Currently call returns should have single vregs.
264 // TODO: handle the case of multiple registers.
265 if (Info.OrigRet.Regs.size() > 1)
266 return false;
267 MachineFunction &MF = MIRBuilder.getMF();
268 GR->setCurrentFunc(MF);
269 FunctionType *FTy = nullptr;
270 const Function *CF = nullptr;
271
272 // Emit a regular OpFunctionCall. If it's an externally declared function,
273 // be sure to emit its type and function declaration here. It will be hoisted
274 // globally later.
275 if (Info.Callee.isGlobal()) {
276 CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
277 // TODO: support constexpr casts and indirect calls.
278 if (CF == nullptr)
279 return false;
280 FTy = getOriginalFunctionType(*CF);
281 }
282
283 Register ResVReg =
284 Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
285 if (CF && CF->isDeclaration() &&
286 !GR->find(CF, &MIRBuilder.getMF()).isValid()) {
287 // Emit the type info and forward function declaration to the first MBB
288 // to ensure VReg definition dependencies are valid across all MBBs.
289 MachineIRBuilder FirstBlockBuilder;
290 FirstBlockBuilder.setMF(MF);
291 FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0));
292
293 SmallVector<ArrayRef<Register>, 8> VRegArgs;
294 SmallVector<SmallVector<Register, 1>, 8> ToInsert;
295 for (const Argument &Arg : CF->args()) {
296 if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
297 continue; // Don't handle zero sized types.
298 ToInsert.push_back(
299 {MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32))});
300 VRegArgs.push_back(ToInsert.back());
301 }
302 // TODO: Reuse FunctionLoweringInfo
303 FunctionLoweringInfo FuncInfo;
304 lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
305 }
306
307 // Make sure there's a valid return reg, even for functions returning void.
308 if (!ResVReg.isValid())
309 ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
310 SPIRVType *RetType =
311 GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);
312
313 // Emit the OpFunctionCall and its args.
314 auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
315 .addDef(ResVReg)
316 .addUse(GR->getSPIRVTypeID(RetType))
317 .add(Info.Callee);
318
319 for (const auto &Arg : Info.OrigArgs) {
320 // Currently call args should have single vregs.
321 if (Arg.Regs.size() > 1)
322 return false;
323 MIB.addUse(Arg.Regs[0]);
324 }
325 const auto &STI = MF.getSubtarget();
326 return MIB.constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
327 *STI.getRegBankInfo());
328 }
329