1d16eff81SEugene Zelenko //===- AMDGPURewriteOutArgumentsPass.cpp - Create struct returns ----------===//
2c06574ffSMatt Arsenault //
32946cd70SChandler Carruth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42946cd70SChandler Carruth // See https://llvm.org/LICENSE.txt for license information.
52946cd70SChandler Carruth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c06574ffSMatt Arsenault //
7c06574ffSMatt Arsenault //===----------------------------------------------------------------------===//
8c06574ffSMatt Arsenault //
9c06574ffSMatt Arsenault /// \file This pass attempts to replace out argument usage with a return of a
10c06574ffSMatt Arsenault /// struct.
11c06574ffSMatt Arsenault ///
12c06574ffSMatt Arsenault /// We can support returning a lot of values directly in registers, but
13c06574ffSMatt Arsenault /// idiomatic C code frequently uses a pointer argument to return a second value
14c06574ffSMatt Arsenault /// rather than returning a struct by value. GPU stack access is also quite
15c06574ffSMatt Arsenault /// painful, so we want to avoid that if possible. Passing a stack object
16c06574ffSMatt Arsenault /// pointer to a function also requires an additional address expansion code
17c06574ffSMatt Arsenault /// sequence to convert the pointer to be relative to the kernel's scratch wave
18c06574ffSMatt Arsenault /// offset register since the callee doesn't know what stack frame the incoming
19c06574ffSMatt Arsenault /// pointer is relative to.
20c06574ffSMatt Arsenault ///
21c06574ffSMatt Arsenault /// The goal is to try rewriting code that looks like this:
22c06574ffSMatt Arsenault ///
23c06574ffSMatt Arsenault /// int foo(int a, int b, int* out) {
24c06574ffSMatt Arsenault /// *out = bar();
25c06574ffSMatt Arsenault /// return a + b;
26c06574ffSMatt Arsenault /// }
27c06574ffSMatt Arsenault ///
28c06574ffSMatt Arsenault /// into something like this:
29c06574ffSMatt Arsenault ///
30c06574ffSMatt Arsenault /// std::pair<int, int> foo(int a, int b) {
31c06574ffSMatt Arsenault /// return std::make_pair(a + b, bar());
32c06574ffSMatt Arsenault /// }
33c06574ffSMatt Arsenault ///
/// Typically the incoming pointer is a simple alloca for a temporary variable
/// used only to call the API, which, if replaced with a struct return, will be
/// easily SROA'd out when the stub function we create is inlined.
37c06574ffSMatt Arsenault ///
38c06574ffSMatt Arsenault /// This pass introduces the struct return, but leaves the unused pointer
39c06574ffSMatt Arsenault /// arguments and introduces a new stub function calling the struct returning
40c06574ffSMatt Arsenault /// body. DeadArgumentElimination should be run after this to clean these up.
41c06574ffSMatt Arsenault //
42c06574ffSMatt Arsenault //===----------------------------------------------------------------------===//
43c06574ffSMatt Arsenault
44c06574ffSMatt Arsenault #include "AMDGPU.h"
45c06574ffSMatt Arsenault #include "Utils/AMDGPUBaseInfo.h"
46d16eff81SEugene Zelenko #include "llvm/ADT/SmallSet.h"
47c06574ffSMatt Arsenault #include "llvm/ADT/Statistic.h"
4805da2fe5SReid Kleckner #include "llvm/Analysis/MemoryDependenceAnalysis.h"
49c06574ffSMatt Arsenault #include "llvm/IR/IRBuilder.h"
50d16eff81SEugene Zelenko #include "llvm/IR/Instructions.h"
5105da2fe5SReid Kleckner #include "llvm/InitializePasses.h"
52d16eff81SEugene Zelenko #include "llvm/Pass.h"
53d16eff81SEugene Zelenko #include "llvm/Support/CommandLine.h"
54c06574ffSMatt Arsenault #include "llvm/Support/Debug.h"
55d16eff81SEugene Zelenko #include "llvm/Support/raw_ostream.h"
56c06574ffSMatt Arsenault
57c06574ffSMatt Arsenault #define DEBUG_TYPE "amdgpu-rewrite-out-arguments"
58c06574ffSMatt Arsenault
59c06574ffSMatt Arsenault using namespace llvm;
60c06574ffSMatt Arsenault
61c06574ffSMatt Arsenault static cl::opt<bool> AnyAddressSpace(
62c06574ffSMatt Arsenault "amdgpu-any-address-space-out-arguments",
63c06574ffSMatt Arsenault cl::desc("Replace pointer out arguments with "
64c06574ffSMatt Arsenault "struct returns for non-private address space"),
65c06574ffSMatt Arsenault cl::Hidden,
66c06574ffSMatt Arsenault cl::init(false));
67c06574ffSMatt Arsenault
68c06574ffSMatt Arsenault static cl::opt<unsigned> MaxNumRetRegs(
69c06574ffSMatt Arsenault "amdgpu-max-return-arg-num-regs",
70c06574ffSMatt Arsenault cl::desc("Approximately limit number of return registers for replacing out arguments"),
71c06574ffSMatt Arsenault cl::Hidden,
72c06574ffSMatt Arsenault cl::init(16));
73c06574ffSMatt Arsenault
74c06574ffSMatt Arsenault STATISTIC(NumOutArgumentsReplaced,
75c06574ffSMatt Arsenault "Number out arguments moved to struct return values");
76c06574ffSMatt Arsenault STATISTIC(NumOutArgumentFunctionsReplaced,
77c06574ffSMatt Arsenault "Number of functions with out arguments moved to struct return values");
78c06574ffSMatt Arsenault
79d16eff81SEugene Zelenko namespace {
80d16eff81SEugene Zelenko
81c06574ffSMatt Arsenault class AMDGPURewriteOutArguments : public FunctionPass {
82c06574ffSMatt Arsenault private:
83c06574ffSMatt Arsenault const DataLayout *DL = nullptr;
84c06574ffSMatt Arsenault MemoryDependenceResults *MDA = nullptr;
85c06574ffSMatt Arsenault
86*99702734SNikita Popov Type *getStoredType(Value &Arg) const;
87*99702734SNikita Popov Type *getOutArgumentType(Argument &Arg) const;
88d16eff81SEugene Zelenko
89c06574ffSMatt Arsenault public:
90c06574ffSMatt Arsenault static char ID;
91c06574ffSMatt Arsenault
AMDGPURewriteOutArguments()92d16eff81SEugene Zelenko AMDGPURewriteOutArguments() : FunctionPass(ID) {}
93c06574ffSMatt Arsenault
getAnalysisUsage(AnalysisUsage & AU) const94c06574ffSMatt Arsenault void getAnalysisUsage(AnalysisUsage &AU) const override {
95c06574ffSMatt Arsenault AU.addRequired<MemoryDependenceWrapperPass>();
96c06574ffSMatt Arsenault FunctionPass::getAnalysisUsage(AU);
97c06574ffSMatt Arsenault }
98c06574ffSMatt Arsenault
99c06574ffSMatt Arsenault bool doInitialization(Module &M) override;
100d16eff81SEugene Zelenko bool runOnFunction(Function &F) override;
101c06574ffSMatt Arsenault };
102c06574ffSMatt Arsenault
103d16eff81SEugene Zelenko } // end anonymous namespace
104c06574ffSMatt Arsenault
105c06574ffSMatt Arsenault INITIALIZE_PASS_BEGIN(AMDGPURewriteOutArguments, DEBUG_TYPE,
106c06574ffSMatt Arsenault "AMDGPU Rewrite Out Arguments", false, false)
107c06574ffSMatt Arsenault INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
108c06574ffSMatt Arsenault INITIALIZE_PASS_END(AMDGPURewriteOutArguments, DEBUG_TYPE,
109c06574ffSMatt Arsenault "AMDGPU Rewrite Out Arguments", false, false)
110c06574ffSMatt Arsenault
111c06574ffSMatt Arsenault char AMDGPURewriteOutArguments::ID = 0;
112c06574ffSMatt Arsenault
getStoredType(Value & Arg) const113*99702734SNikita Popov Type *AMDGPURewriteOutArguments::getStoredType(Value &Arg) const {
114c06574ffSMatt Arsenault const int MaxUses = 10;
115c06574ffSMatt Arsenault int UseCount = 0;
116c06574ffSMatt Arsenault
117*99702734SNikita Popov SmallVector<Use *> Worklist;
118*99702734SNikita Popov for (Use &U : Arg.uses())
119*99702734SNikita Popov Worklist.push_back(&U);
120c06574ffSMatt Arsenault
121*99702734SNikita Popov Type *StoredType = nullptr;
122*99702734SNikita Popov while (!Worklist.empty()) {
123*99702734SNikita Popov Use *U = Worklist.pop_back_val();
124da9ab148SMatt Arsenault
125*99702734SNikita Popov if (auto *BCI = dyn_cast<BitCastInst>(U->getUser())) {
126*99702734SNikita Popov for (Use &U : BCI->uses())
127*99702734SNikita Popov Worklist.push_back(&U);
128*99702734SNikita Popov continue;
129da9ab148SMatt Arsenault }
130da9ab148SMatt Arsenault
131*99702734SNikita Popov if (auto *SI = dyn_cast<StoreInst>(U->getUser())) {
132*99702734SNikita Popov if (UseCount++ > MaxUses)
133*99702734SNikita Popov return nullptr;
134*99702734SNikita Popov
135da9ab148SMatt Arsenault if (!SI->isSimple() ||
136*99702734SNikita Popov U->getOperandNo() != StoreInst::getPointerOperandIndex())
137*99702734SNikita Popov return nullptr;
138c06574ffSMatt Arsenault
139*99702734SNikita Popov if (StoredType && StoredType != SI->getValueOperand()->getType())
140*99702734SNikita Popov return nullptr; // More than one type.
141*99702734SNikita Popov StoredType = SI->getValueOperand()->getType();
142*99702734SNikita Popov continue;
143c06574ffSMatt Arsenault }
144c06574ffSMatt Arsenault
145*99702734SNikita Popov // Unsupported user.
146*99702734SNikita Popov return nullptr;
147c06574ffSMatt Arsenault }
148c06574ffSMatt Arsenault
149*99702734SNikita Popov return StoredType;
150*99702734SNikita Popov }
151*99702734SNikita Popov
getOutArgumentType(Argument & Arg) const152*99702734SNikita Popov Type *AMDGPURewriteOutArguments::getOutArgumentType(Argument &Arg) const {
153da9ab148SMatt Arsenault const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs;
154da9ab148SMatt Arsenault PointerType *ArgTy = dyn_cast<PointerType>(Arg.getType());
155da9ab148SMatt Arsenault
156da9ab148SMatt Arsenault // TODO: It might be useful for any out arguments, not just privates.
157da9ab148SMatt Arsenault if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() &&
158da9ab148SMatt Arsenault !AnyAddressSpace) ||
159*99702734SNikita Popov Arg.hasByValAttr() || Arg.hasStructRetAttr()) {
160*99702734SNikita Popov return nullptr;
161da9ab148SMatt Arsenault }
162da9ab148SMatt Arsenault
163*99702734SNikita Popov Type *StoredType = getStoredType(Arg);
164*99702734SNikita Popov if (!StoredType || DL->getTypeStoreSize(StoredType) > MaxOutArgSizeBytes)
165*99702734SNikita Popov return nullptr;
166*99702734SNikita Popov
167*99702734SNikita Popov return StoredType;
168da9ab148SMatt Arsenault }
169da9ab148SMatt Arsenault
doInitialization(Module & M)170c06574ffSMatt Arsenault bool AMDGPURewriteOutArguments::doInitialization(Module &M) {
171c06574ffSMatt Arsenault DL = &M.getDataLayout();
172c06574ffSMatt Arsenault return false;
173c06574ffSMatt Arsenault }
174c06574ffSMatt Arsenault
runOnFunction(Function & F)175c06574ffSMatt Arsenault bool AMDGPURewriteOutArguments::runOnFunction(Function &F) {
176c06574ffSMatt Arsenault if (skipFunction(F))
177c06574ffSMatt Arsenault return false;
178c06574ffSMatt Arsenault
179c06574ffSMatt Arsenault // TODO: Could probably handle variadic functions.
180c06574ffSMatt Arsenault if (F.isVarArg() || F.hasStructRetAttr() ||
181c06574ffSMatt Arsenault AMDGPU::isEntryFunctionCC(F.getCallingConv()))
182c06574ffSMatt Arsenault return false;
183c06574ffSMatt Arsenault
184c06574ffSMatt Arsenault MDA = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
185c06574ffSMatt Arsenault
186c06574ffSMatt Arsenault unsigned ReturnNumRegs = 0;
187*99702734SNikita Popov SmallDenseMap<int, Type *, 4> OutArgIndexes;
188c06574ffSMatt Arsenault SmallVector<Type *, 4> ReturnTypes;
189c06574ffSMatt Arsenault Type *RetTy = F.getReturnType();
190c06574ffSMatt Arsenault if (!RetTy->isVoidTy()) {
191c06574ffSMatt Arsenault ReturnNumRegs = DL->getTypeStoreSize(RetTy) / 4;
192c06574ffSMatt Arsenault
193c06574ffSMatt Arsenault if (ReturnNumRegs >= MaxNumRetRegs)
194c06574ffSMatt Arsenault return false;
195c06574ffSMatt Arsenault
196c06574ffSMatt Arsenault ReturnTypes.push_back(RetTy);
197c06574ffSMatt Arsenault }
198c06574ffSMatt Arsenault
199*99702734SNikita Popov SmallVector<std::pair<Argument *, Type *>, 4> OutArgs;
200c06574ffSMatt Arsenault for (Argument &Arg : F.args()) {
201*99702734SNikita Popov if (Type *Ty = getOutArgumentType(Arg)) {
202d34e60caSNicola Zaghen LLVM_DEBUG(dbgs() << "Found possible out argument " << Arg
203c06574ffSMatt Arsenault << " in function " << F.getName() << '\n');
204*99702734SNikita Popov OutArgs.push_back({&Arg, Ty});
205c06574ffSMatt Arsenault }
206c06574ffSMatt Arsenault }
207c06574ffSMatt Arsenault
208c06574ffSMatt Arsenault if (OutArgs.empty())
209c06574ffSMatt Arsenault return false;
210c06574ffSMatt Arsenault
211d16eff81SEugene Zelenko using ReplacementVec = SmallVector<std::pair<Argument *, Value *>, 4>;
212d16eff81SEugene Zelenko
213c06574ffSMatt Arsenault DenseMap<ReturnInst *, ReplacementVec> Replacements;
214c06574ffSMatt Arsenault
215c06574ffSMatt Arsenault SmallVector<ReturnInst *, 4> Returns;
216c06574ffSMatt Arsenault for (BasicBlock &BB : F) {
217c06574ffSMatt Arsenault if (ReturnInst *RI = dyn_cast<ReturnInst>(&BB.back()))
218c06574ffSMatt Arsenault Returns.push_back(RI);
219c06574ffSMatt Arsenault }
220c06574ffSMatt Arsenault
221c06574ffSMatt Arsenault if (Returns.empty())
222c06574ffSMatt Arsenault return false;
223c06574ffSMatt Arsenault
224c06574ffSMatt Arsenault bool Changing;
225c06574ffSMatt Arsenault
226c06574ffSMatt Arsenault do {
227c06574ffSMatt Arsenault Changing = false;
228c06574ffSMatt Arsenault
229c06574ffSMatt Arsenault // Keep retrying if we are able to successfully eliminate an argument. This
230c06574ffSMatt Arsenault // helps with cases with multiple arguments which may alias, such as in a
231dc6e8dfdSJacob Lambert // sincos implementation. If we have 2 stores to arguments, on the first
232c06574ffSMatt Arsenault // attempt the MDA query will succeed for the second store but not the
233c06574ffSMatt Arsenault // first. On the second iteration we've removed that out clobbering argument
234c06574ffSMatt Arsenault // (by effectively moving it into another function) and will find the second
235c06574ffSMatt Arsenault // argument is OK to move.
236*99702734SNikita Popov for (const auto &Pair : OutArgs) {
237c06574ffSMatt Arsenault bool ThisReplaceable = true;
238c06574ffSMatt Arsenault SmallVector<std::pair<ReturnInst *, StoreInst *>, 4> ReplaceableStores;
239c06574ffSMatt Arsenault
240*99702734SNikita Popov Argument *OutArg = Pair.first;
241*99702734SNikita Popov Type *ArgTy = Pair.second;
242c06574ffSMatt Arsenault
243c06574ffSMatt Arsenault // Skip this argument if converting it will push us over the register
244c06574ffSMatt Arsenault // count to return limit.
245c06574ffSMatt Arsenault
246c06574ffSMatt Arsenault // TODO: This is an approximation. When legalized this could be more. We
247c06574ffSMatt Arsenault // can ask TLI for exactly how many.
248c06574ffSMatt Arsenault unsigned ArgNumRegs = DL->getTypeStoreSize(ArgTy) / 4;
249c06574ffSMatt Arsenault if (ArgNumRegs + ReturnNumRegs > MaxNumRetRegs)
250c06574ffSMatt Arsenault continue;
251c06574ffSMatt Arsenault
252c06574ffSMatt Arsenault // An argument is convertible only if all exit blocks are able to replace
253c06574ffSMatt Arsenault // it.
254c06574ffSMatt Arsenault for (ReturnInst *RI : Returns) {
255c06574ffSMatt Arsenault BasicBlock *BB = RI->getParent();
256c06574ffSMatt Arsenault
257393b9e9dSNikita Popov MemDepResult Q = MDA->getPointerDependencyFrom(
2584df8efceSNikita Popov MemoryLocation::getBeforeOrAfter(OutArg), true, BB->end(), BB, RI);
259c06574ffSMatt Arsenault StoreInst *SI = nullptr;
260c06574ffSMatt Arsenault if (Q.isDef())
261c06574ffSMatt Arsenault SI = dyn_cast<StoreInst>(Q.getInst());
262c06574ffSMatt Arsenault
263c06574ffSMatt Arsenault if (SI) {
264d34e60caSNicola Zaghen LLVM_DEBUG(dbgs() << "Found out argument store: " << *SI << '\n');
265c06574ffSMatt Arsenault ReplaceableStores.emplace_back(RI, SI);
266c06574ffSMatt Arsenault } else {
267c06574ffSMatt Arsenault ThisReplaceable = false;
268c06574ffSMatt Arsenault break;
269c06574ffSMatt Arsenault }
270c06574ffSMatt Arsenault }
271c06574ffSMatt Arsenault
272c06574ffSMatt Arsenault if (!ThisReplaceable)
273c06574ffSMatt Arsenault continue; // Try the next argument candidate.
274c06574ffSMatt Arsenault
275c06574ffSMatt Arsenault for (std::pair<ReturnInst *, StoreInst *> Store : ReplaceableStores) {
276c06574ffSMatt Arsenault Value *ReplVal = Store.second->getValueOperand();
277c06574ffSMatt Arsenault
278c06574ffSMatt Arsenault auto &ValVec = Replacements[Store.first];
279d6ff5cf9SKazu Hirata if (llvm::any_of(ValVec,
280c06574ffSMatt Arsenault [OutArg](const std::pair<Argument *, Value *> &Entry) {
281d6ff5cf9SKazu Hirata return Entry.first == OutArg;
282d6ff5cf9SKazu Hirata })) {
283d34e60caSNicola Zaghen LLVM_DEBUG(dbgs()
284d34e60caSNicola Zaghen << "Saw multiple out arg stores" << *OutArg << '\n');
285c06574ffSMatt Arsenault // It is possible to see stores to the same argument multiple times,
286c06574ffSMatt Arsenault // but we expect these would have been optimized out already.
287c06574ffSMatt Arsenault ThisReplaceable = false;
288c06574ffSMatt Arsenault break;
289c06574ffSMatt Arsenault }
290c06574ffSMatt Arsenault
291c06574ffSMatt Arsenault ValVec.emplace_back(OutArg, ReplVal);
292c06574ffSMatt Arsenault Store.second->eraseFromParent();
293c06574ffSMatt Arsenault }
294c06574ffSMatt Arsenault
295c06574ffSMatt Arsenault if (ThisReplaceable) {
296c06574ffSMatt Arsenault ReturnTypes.push_back(ArgTy);
297*99702734SNikita Popov OutArgIndexes.insert({OutArg->getArgNo(), ArgTy});
298c06574ffSMatt Arsenault ++NumOutArgumentsReplaced;
299c06574ffSMatt Arsenault Changing = true;
300c06574ffSMatt Arsenault }
301c06574ffSMatt Arsenault }
302c06574ffSMatt Arsenault } while (Changing);
303c06574ffSMatt Arsenault
304c06574ffSMatt Arsenault if (Replacements.empty())
305c06574ffSMatt Arsenault return false;
306c06574ffSMatt Arsenault
307c06574ffSMatt Arsenault LLVMContext &Ctx = F.getParent()->getContext();
308c06574ffSMatt Arsenault StructType *NewRetTy = StructType::create(Ctx, ReturnTypes, F.getName());
309c06574ffSMatt Arsenault
310c06574ffSMatt Arsenault FunctionType *NewFuncTy = FunctionType::get(NewRetTy,
311c06574ffSMatt Arsenault F.getFunctionType()->params(),
312c06574ffSMatt Arsenault F.isVarArg());
313c06574ffSMatt Arsenault
314d34e60caSNicola Zaghen LLVM_DEBUG(dbgs() << "Computed new return type: " << *NewRetTy << '\n');
315c06574ffSMatt Arsenault
316c06574ffSMatt Arsenault Function *NewFunc = Function::Create(NewFuncTy, Function::PrivateLinkage,
317c06574ffSMatt Arsenault F.getName() + ".body");
318c06574ffSMatt Arsenault F.getParent()->getFunctionList().insert(F.getIterator(), NewFunc);
319c06574ffSMatt Arsenault NewFunc->copyAttributesFrom(&F);
320c06574ffSMatt Arsenault NewFunc->setComdat(F.getComdat());
321c06574ffSMatt Arsenault
322c06574ffSMatt Arsenault // We want to preserve the function and param attributes, but need to strip
323c06574ffSMatt Arsenault // off any return attributes, e.g. zeroext doesn't make sense with a struct.
324c06574ffSMatt Arsenault NewFunc->stealArgumentListFrom(F);
325c06574ffSMatt Arsenault
3269290ccc3Sserge-sans-paille AttributeMask RetAttrs;
327c06574ffSMatt Arsenault RetAttrs.addAttribute(Attribute::SExt);
328c06574ffSMatt Arsenault RetAttrs.addAttribute(Attribute::ZExt);
329c06574ffSMatt Arsenault RetAttrs.addAttribute(Attribute::NoAlias);
330ad727ab7SArthur Eubanks NewFunc->removeRetAttrs(RetAttrs);
331c06574ffSMatt Arsenault // TODO: How to preserve metadata?
332c06574ffSMatt Arsenault
333c06574ffSMatt Arsenault // Move the body of the function into the new rewritten function, and replace
334c06574ffSMatt Arsenault // this function with a stub.
335c06574ffSMatt Arsenault NewFunc->getBasicBlockList().splice(NewFunc->begin(), F.getBasicBlockList());
336c06574ffSMatt Arsenault
337c06574ffSMatt Arsenault for (std::pair<ReturnInst *, ReplacementVec> &Replacement : Replacements) {
338c06574ffSMatt Arsenault ReturnInst *RI = Replacement.first;
339c06574ffSMatt Arsenault IRBuilder<> B(RI);
340c06574ffSMatt Arsenault B.SetCurrentDebugLocation(RI->getDebugLoc());
341c06574ffSMatt Arsenault
342c06574ffSMatt Arsenault int RetIdx = 0;
343c06574ffSMatt Arsenault Value *NewRetVal = UndefValue::get(NewRetTy);
344c06574ffSMatt Arsenault
345c06574ffSMatt Arsenault Value *RetVal = RI->getReturnValue();
346c06574ffSMatt Arsenault if (RetVal)
347c06574ffSMatt Arsenault NewRetVal = B.CreateInsertValue(NewRetVal, RetVal, RetIdx++);
348c06574ffSMatt Arsenault
349*99702734SNikita Popov for (std::pair<Argument *, Value *> ReturnPoint : Replacement.second)
350*99702734SNikita Popov NewRetVal = B.CreateInsertValue(NewRetVal, ReturnPoint.second, RetIdx++);
351c06574ffSMatt Arsenault
352c06574ffSMatt Arsenault if (RetVal)
353c06574ffSMatt Arsenault RI->setOperand(0, NewRetVal);
354c06574ffSMatt Arsenault else {
355c06574ffSMatt Arsenault B.CreateRet(NewRetVal);
356c06574ffSMatt Arsenault RI->eraseFromParent();
357c06574ffSMatt Arsenault }
358c06574ffSMatt Arsenault }
359c06574ffSMatt Arsenault
360c06574ffSMatt Arsenault SmallVector<Value *, 16> StubCallArgs;
361c06574ffSMatt Arsenault for (Argument &Arg : F.args()) {
362c06574ffSMatt Arsenault if (OutArgIndexes.count(Arg.getArgNo())) {
363c06574ffSMatt Arsenault // It's easier to preserve the type of the argument list. We rely on
364c06574ffSMatt Arsenault // DeadArgumentElimination to take care of these.
365c06574ffSMatt Arsenault StubCallArgs.push_back(UndefValue::get(Arg.getType()));
366c06574ffSMatt Arsenault } else {
367c06574ffSMatt Arsenault StubCallArgs.push_back(&Arg);
368c06574ffSMatt Arsenault }
369c06574ffSMatt Arsenault }
370c06574ffSMatt Arsenault
371c06574ffSMatt Arsenault BasicBlock *StubBB = BasicBlock::Create(Ctx, "", &F);
372c06574ffSMatt Arsenault IRBuilder<> B(StubBB);
373c06574ffSMatt Arsenault CallInst *StubCall = B.CreateCall(NewFunc, StubCallArgs);
374c06574ffSMatt Arsenault
375c06574ffSMatt Arsenault int RetIdx = RetTy->isVoidTy() ? 0 : 1;
376c06574ffSMatt Arsenault for (Argument &Arg : F.args()) {
377c06574ffSMatt Arsenault if (!OutArgIndexes.count(Arg.getArgNo()))
378c06574ffSMatt Arsenault continue;
379c06574ffSMatt Arsenault
380da9ab148SMatt Arsenault PointerType *ArgType = cast<PointerType>(Arg.getType());
381da9ab148SMatt Arsenault
382*99702734SNikita Popov Type *EltTy = OutArgIndexes[Arg.getArgNo()];
38359f95222SGuillaume Chatelet const auto Align =
38459f95222SGuillaume Chatelet DL->getValueOrABITypeAlignment(Arg.getParamAlign(), EltTy);
385c06574ffSMatt Arsenault
386c06574ffSMatt Arsenault Value *Val = B.CreateExtractValue(StubCall, RetIdx++);
387da9ab148SMatt Arsenault Type *PtrTy = Val->getType()->getPointerTo(ArgType->getAddressSpace());
388da9ab148SMatt Arsenault
389da9ab148SMatt Arsenault // We can peek through bitcasts, so the type may not match.
390da9ab148SMatt Arsenault Value *PtrVal = B.CreateBitCast(&Arg, PtrTy);
391da9ab148SMatt Arsenault
392da9ab148SMatt Arsenault B.CreateAlignedStore(Val, PtrVal, Align);
393c06574ffSMatt Arsenault }
394c06574ffSMatt Arsenault
395c06574ffSMatt Arsenault if (!RetTy->isVoidTy()) {
396c06574ffSMatt Arsenault B.CreateRet(B.CreateExtractValue(StubCall, 0));
397c06574ffSMatt Arsenault } else {
398c06574ffSMatt Arsenault B.CreateRetVoid();
399c06574ffSMatt Arsenault }
400c06574ffSMatt Arsenault
401c06574ffSMatt Arsenault // The function is now a stub we want to inline.
402c06574ffSMatt Arsenault F.addFnAttr(Attribute::AlwaysInline);
403c06574ffSMatt Arsenault
404c06574ffSMatt Arsenault ++NumOutArgumentFunctionsReplaced;
405c06574ffSMatt Arsenault return true;
406c06574ffSMatt Arsenault }
407c06574ffSMatt Arsenault
createAMDGPURewriteOutArgumentsPass()408c06574ffSMatt Arsenault FunctionPass *llvm::createAMDGPURewriteOutArgumentsPass() {
409c06574ffSMatt Arsenault return new AMDGPURewriteOutArguments();
410c06574ffSMatt Arsenault }
411