1d16eff81SEugene Zelenko //===- AMDGPURewriteOutArgumentsPass.cpp - Create struct returns ----------===//
2c06574ffSMatt Arsenault //
32946cd70SChandler Carruth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42946cd70SChandler Carruth // See https://llvm.org/LICENSE.txt for license information.
52946cd70SChandler Carruth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c06574ffSMatt Arsenault //
7c06574ffSMatt Arsenault //===----------------------------------------------------------------------===//
8c06574ffSMatt Arsenault //
9c06574ffSMatt Arsenault /// \file This pass attempts to replace out argument usage with a return of a
10c06574ffSMatt Arsenault /// struct.
11c06574ffSMatt Arsenault ///
12c06574ffSMatt Arsenault /// We can support returning a lot of values directly in registers, but
13c06574ffSMatt Arsenault /// idiomatic C code frequently uses a pointer argument to return a second value
14c06574ffSMatt Arsenault /// rather than returning a struct by value. GPU stack access is also quite
15c06574ffSMatt Arsenault /// painful, so we want to avoid that if possible. Passing a stack object
16c06574ffSMatt Arsenault /// pointer to a function also requires an additional address expansion code
17c06574ffSMatt Arsenault /// sequence to convert the pointer to be relative to the kernel's scratch wave
18c06574ffSMatt Arsenault /// offset register since the callee doesn't know what stack frame the incoming
19c06574ffSMatt Arsenault /// pointer is relative to.
20c06574ffSMatt Arsenault ///
21c06574ffSMatt Arsenault /// The goal is to try rewriting code that looks like this:
22c06574ffSMatt Arsenault ///
23c06574ffSMatt Arsenault ///  int foo(int a, int b, int* out) {
24c06574ffSMatt Arsenault ///     *out = bar();
25c06574ffSMatt Arsenault ///     return a + b;
26c06574ffSMatt Arsenault /// }
27c06574ffSMatt Arsenault ///
28c06574ffSMatt Arsenault /// into something like this:
29c06574ffSMatt Arsenault ///
30c06574ffSMatt Arsenault ///  std::pair<int, int> foo(int a, int b) {
31c06574ffSMatt Arsenault ///     return std::make_pair(a + b, bar());
32c06574ffSMatt Arsenault /// }
33c06574ffSMatt Arsenault ///
/// Typically the incoming pointer is a simple alloca for a temporary variable
/// to use the API, which if replaced with a struct return will be easily SROA'd
/// out when the stub function we create is inlined.
37c06574ffSMatt Arsenault ///
38c06574ffSMatt Arsenault /// This pass introduces the struct return, but leaves the unused pointer
39c06574ffSMatt Arsenault /// arguments and introduces a new stub function calling the struct returning
40c06574ffSMatt Arsenault /// body. DeadArgumentElimination should be run after this to clean these up.
41c06574ffSMatt Arsenault //
42c06574ffSMatt Arsenault //===----------------------------------------------------------------------===//
43c06574ffSMatt Arsenault 
44c06574ffSMatt Arsenault #include "AMDGPU.h"
45c06574ffSMatt Arsenault #include "Utils/AMDGPUBaseInfo.h"
46d16eff81SEugene Zelenko #include "llvm/ADT/SmallSet.h"
47c06574ffSMatt Arsenault #include "llvm/ADT/Statistic.h"
4805da2fe5SReid Kleckner #include "llvm/Analysis/MemoryDependenceAnalysis.h"
49c06574ffSMatt Arsenault #include "llvm/IR/IRBuilder.h"
50d16eff81SEugene Zelenko #include "llvm/IR/Instructions.h"
5105da2fe5SReid Kleckner #include "llvm/InitializePasses.h"
52d16eff81SEugene Zelenko #include "llvm/Pass.h"
53d16eff81SEugene Zelenko #include "llvm/Support/CommandLine.h"
54c06574ffSMatt Arsenault #include "llvm/Support/Debug.h"
55d16eff81SEugene Zelenko #include "llvm/Support/raw_ostream.h"
56c06574ffSMatt Arsenault 
57c06574ffSMatt Arsenault #define DEBUG_TYPE "amdgpu-rewrite-out-arguments"
58c06574ffSMatt Arsenault 
59c06574ffSMatt Arsenault using namespace llvm;
60c06574ffSMatt Arsenault 
61c06574ffSMatt Arsenault static cl::opt<bool> AnyAddressSpace(
62c06574ffSMatt Arsenault   "amdgpu-any-address-space-out-arguments",
63c06574ffSMatt Arsenault   cl::desc("Replace pointer out arguments with "
64c06574ffSMatt Arsenault            "struct returns for non-private address space"),
65c06574ffSMatt Arsenault   cl::Hidden,
66c06574ffSMatt Arsenault   cl::init(false));
67c06574ffSMatt Arsenault 
68c06574ffSMatt Arsenault static cl::opt<unsigned> MaxNumRetRegs(
69c06574ffSMatt Arsenault   "amdgpu-max-return-arg-num-regs",
70c06574ffSMatt Arsenault   cl::desc("Approximately limit number of return registers for replacing out arguments"),
71c06574ffSMatt Arsenault   cl::Hidden,
72c06574ffSMatt Arsenault   cl::init(16));
73c06574ffSMatt Arsenault 
74c06574ffSMatt Arsenault STATISTIC(NumOutArgumentsReplaced,
75c06574ffSMatt Arsenault           "Number out arguments moved to struct return values");
76c06574ffSMatt Arsenault STATISTIC(NumOutArgumentFunctionsReplaced,
77c06574ffSMatt Arsenault           "Number of functions with out arguments moved to struct return values");
78c06574ffSMatt Arsenault 
79d16eff81SEugene Zelenko namespace {
80d16eff81SEugene Zelenko 
81c06574ffSMatt Arsenault class AMDGPURewriteOutArguments : public FunctionPass {
82c06574ffSMatt Arsenault private:
83c06574ffSMatt Arsenault   const DataLayout *DL = nullptr;
84c06574ffSMatt Arsenault   MemoryDependenceResults *MDA = nullptr;
85c06574ffSMatt Arsenault 
86*99702734SNikita Popov   Type *getStoredType(Value &Arg) const;
87*99702734SNikita Popov   Type *getOutArgumentType(Argument &Arg) const;
88d16eff81SEugene Zelenko 
89c06574ffSMatt Arsenault public:
90c06574ffSMatt Arsenault   static char ID;
91c06574ffSMatt Arsenault 
AMDGPURewriteOutArguments()92d16eff81SEugene Zelenko   AMDGPURewriteOutArguments() : FunctionPass(ID) {}
93c06574ffSMatt Arsenault 
getAnalysisUsage(AnalysisUsage & AU) const94c06574ffSMatt Arsenault   void getAnalysisUsage(AnalysisUsage &AU) const override {
95c06574ffSMatt Arsenault     AU.addRequired<MemoryDependenceWrapperPass>();
96c06574ffSMatt Arsenault     FunctionPass::getAnalysisUsage(AU);
97c06574ffSMatt Arsenault   }
98c06574ffSMatt Arsenault 
99c06574ffSMatt Arsenault   bool doInitialization(Module &M) override;
100d16eff81SEugene Zelenko   bool runOnFunction(Function &F) override;
101c06574ffSMatt Arsenault };
102c06574ffSMatt Arsenault 
103d16eff81SEugene Zelenko } // end anonymous namespace
104c06574ffSMatt Arsenault 
105c06574ffSMatt Arsenault INITIALIZE_PASS_BEGIN(AMDGPURewriteOutArguments, DEBUG_TYPE,
106c06574ffSMatt Arsenault                       "AMDGPU Rewrite Out Arguments", false, false)
107c06574ffSMatt Arsenault INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
108c06574ffSMatt Arsenault INITIALIZE_PASS_END(AMDGPURewriteOutArguments, DEBUG_TYPE,
109c06574ffSMatt Arsenault                     "AMDGPU Rewrite Out Arguments", false, false)
110c06574ffSMatt Arsenault 
111c06574ffSMatt Arsenault char AMDGPURewriteOutArguments::ID = 0;
112c06574ffSMatt Arsenault 
getStoredType(Value & Arg) const113*99702734SNikita Popov Type *AMDGPURewriteOutArguments::getStoredType(Value &Arg) const {
114c06574ffSMatt Arsenault   const int MaxUses = 10;
115c06574ffSMatt Arsenault   int UseCount = 0;
116c06574ffSMatt Arsenault 
117*99702734SNikita Popov   SmallVector<Use *> Worklist;
118*99702734SNikita Popov   for (Use &U : Arg.uses())
119*99702734SNikita Popov     Worklist.push_back(&U);
120c06574ffSMatt Arsenault 
121*99702734SNikita Popov   Type *StoredType = nullptr;
122*99702734SNikita Popov   while (!Worklist.empty()) {
123*99702734SNikita Popov     Use *U = Worklist.pop_back_val();
124da9ab148SMatt Arsenault 
125*99702734SNikita Popov     if (auto *BCI = dyn_cast<BitCastInst>(U->getUser())) {
126*99702734SNikita Popov       for (Use &U : BCI->uses())
127*99702734SNikita Popov         Worklist.push_back(&U);
128*99702734SNikita Popov       continue;
129da9ab148SMatt Arsenault     }
130da9ab148SMatt Arsenault 
131*99702734SNikita Popov     if (auto *SI = dyn_cast<StoreInst>(U->getUser())) {
132*99702734SNikita Popov       if (UseCount++ > MaxUses)
133*99702734SNikita Popov         return nullptr;
134*99702734SNikita Popov 
135da9ab148SMatt Arsenault       if (!SI->isSimple() ||
136*99702734SNikita Popov           U->getOperandNo() != StoreInst::getPointerOperandIndex())
137*99702734SNikita Popov         return nullptr;
138c06574ffSMatt Arsenault 
139*99702734SNikita Popov       if (StoredType && StoredType != SI->getValueOperand()->getType())
140*99702734SNikita Popov         return nullptr; // More than one type.
141*99702734SNikita Popov       StoredType = SI->getValueOperand()->getType();
142*99702734SNikita Popov       continue;
143c06574ffSMatt Arsenault     }
144c06574ffSMatt Arsenault 
145*99702734SNikita Popov     // Unsupported user.
146*99702734SNikita Popov     return nullptr;
147c06574ffSMatt Arsenault   }
148c06574ffSMatt Arsenault 
149*99702734SNikita Popov   return StoredType;
150*99702734SNikita Popov }
151*99702734SNikita Popov 
getOutArgumentType(Argument & Arg) const152*99702734SNikita Popov Type *AMDGPURewriteOutArguments::getOutArgumentType(Argument &Arg) const {
153da9ab148SMatt Arsenault   const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs;
154da9ab148SMatt Arsenault   PointerType *ArgTy = dyn_cast<PointerType>(Arg.getType());
155da9ab148SMatt Arsenault 
156da9ab148SMatt Arsenault   // TODO: It might be useful for any out arguments, not just privates.
157da9ab148SMatt Arsenault   if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() &&
158da9ab148SMatt Arsenault                  !AnyAddressSpace) ||
159*99702734SNikita Popov       Arg.hasByValAttr() || Arg.hasStructRetAttr()) {
160*99702734SNikita Popov     return nullptr;
161da9ab148SMatt Arsenault   }
162da9ab148SMatt Arsenault 
163*99702734SNikita Popov   Type *StoredType = getStoredType(Arg);
164*99702734SNikita Popov   if (!StoredType || DL->getTypeStoreSize(StoredType) > MaxOutArgSizeBytes)
165*99702734SNikita Popov     return nullptr;
166*99702734SNikita Popov 
167*99702734SNikita Popov   return StoredType;
168da9ab148SMatt Arsenault }
169da9ab148SMatt Arsenault 
doInitialization(Module & M)170c06574ffSMatt Arsenault bool AMDGPURewriteOutArguments::doInitialization(Module &M) {
171c06574ffSMatt Arsenault   DL = &M.getDataLayout();
172c06574ffSMatt Arsenault   return false;
173c06574ffSMatt Arsenault }
174c06574ffSMatt Arsenault 
runOnFunction(Function & F)175c06574ffSMatt Arsenault bool AMDGPURewriteOutArguments::runOnFunction(Function &F) {
176c06574ffSMatt Arsenault   if (skipFunction(F))
177c06574ffSMatt Arsenault     return false;
178c06574ffSMatt Arsenault 
179c06574ffSMatt Arsenault   // TODO: Could probably handle variadic functions.
180c06574ffSMatt Arsenault   if (F.isVarArg() || F.hasStructRetAttr() ||
181c06574ffSMatt Arsenault       AMDGPU::isEntryFunctionCC(F.getCallingConv()))
182c06574ffSMatt Arsenault     return false;
183c06574ffSMatt Arsenault 
184c06574ffSMatt Arsenault   MDA = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
185c06574ffSMatt Arsenault 
186c06574ffSMatt Arsenault   unsigned ReturnNumRegs = 0;
187*99702734SNikita Popov   SmallDenseMap<int, Type *, 4> OutArgIndexes;
188c06574ffSMatt Arsenault   SmallVector<Type *, 4> ReturnTypes;
189c06574ffSMatt Arsenault   Type *RetTy = F.getReturnType();
190c06574ffSMatt Arsenault   if (!RetTy->isVoidTy()) {
191c06574ffSMatt Arsenault     ReturnNumRegs = DL->getTypeStoreSize(RetTy) / 4;
192c06574ffSMatt Arsenault 
193c06574ffSMatt Arsenault     if (ReturnNumRegs >= MaxNumRetRegs)
194c06574ffSMatt Arsenault       return false;
195c06574ffSMatt Arsenault 
196c06574ffSMatt Arsenault     ReturnTypes.push_back(RetTy);
197c06574ffSMatt Arsenault   }
198c06574ffSMatt Arsenault 
199*99702734SNikita Popov   SmallVector<std::pair<Argument *, Type *>, 4> OutArgs;
200c06574ffSMatt Arsenault   for (Argument &Arg : F.args()) {
201*99702734SNikita Popov     if (Type *Ty = getOutArgumentType(Arg)) {
202d34e60caSNicola Zaghen       LLVM_DEBUG(dbgs() << "Found possible out argument " << Arg
203c06574ffSMatt Arsenault                         << " in function " << F.getName() << '\n');
204*99702734SNikita Popov       OutArgs.push_back({&Arg, Ty});
205c06574ffSMatt Arsenault     }
206c06574ffSMatt Arsenault   }
207c06574ffSMatt Arsenault 
208c06574ffSMatt Arsenault   if (OutArgs.empty())
209c06574ffSMatt Arsenault     return false;
210c06574ffSMatt Arsenault 
211d16eff81SEugene Zelenko   using ReplacementVec = SmallVector<std::pair<Argument *, Value *>, 4>;
212d16eff81SEugene Zelenko 
213c06574ffSMatt Arsenault   DenseMap<ReturnInst *, ReplacementVec> Replacements;
214c06574ffSMatt Arsenault 
215c06574ffSMatt Arsenault   SmallVector<ReturnInst *, 4> Returns;
216c06574ffSMatt Arsenault   for (BasicBlock &BB : F) {
217c06574ffSMatt Arsenault     if (ReturnInst *RI = dyn_cast<ReturnInst>(&BB.back()))
218c06574ffSMatt Arsenault       Returns.push_back(RI);
219c06574ffSMatt Arsenault   }
220c06574ffSMatt Arsenault 
221c06574ffSMatt Arsenault   if (Returns.empty())
222c06574ffSMatt Arsenault     return false;
223c06574ffSMatt Arsenault 
224c06574ffSMatt Arsenault   bool Changing;
225c06574ffSMatt Arsenault 
226c06574ffSMatt Arsenault   do {
227c06574ffSMatt Arsenault     Changing = false;
228c06574ffSMatt Arsenault 
229c06574ffSMatt Arsenault     // Keep retrying if we are able to successfully eliminate an argument. This
230c06574ffSMatt Arsenault     // helps with cases with multiple arguments which may alias, such as in a
231dc6e8dfdSJacob Lambert     // sincos implementation. If we have 2 stores to arguments, on the first
232c06574ffSMatt Arsenault     // attempt the MDA query will succeed for the second store but not the
233c06574ffSMatt Arsenault     // first. On the second iteration we've removed that out clobbering argument
234c06574ffSMatt Arsenault     // (by effectively moving it into another function) and will find the second
235c06574ffSMatt Arsenault     // argument is OK to move.
236*99702734SNikita Popov     for (const auto &Pair : OutArgs) {
237c06574ffSMatt Arsenault       bool ThisReplaceable = true;
238c06574ffSMatt Arsenault       SmallVector<std::pair<ReturnInst *, StoreInst *>, 4> ReplaceableStores;
239c06574ffSMatt Arsenault 
240*99702734SNikita Popov       Argument *OutArg = Pair.first;
241*99702734SNikita Popov       Type *ArgTy = Pair.second;
242c06574ffSMatt Arsenault 
243c06574ffSMatt Arsenault       // Skip this argument if converting it will push us over the register
244c06574ffSMatt Arsenault       // count to return limit.
245c06574ffSMatt Arsenault 
246c06574ffSMatt Arsenault       // TODO: This is an approximation. When legalized this could be more. We
247c06574ffSMatt Arsenault       // can ask TLI for exactly how many.
248c06574ffSMatt Arsenault       unsigned ArgNumRegs = DL->getTypeStoreSize(ArgTy) / 4;
249c06574ffSMatt Arsenault       if (ArgNumRegs + ReturnNumRegs > MaxNumRetRegs)
250c06574ffSMatt Arsenault         continue;
251c06574ffSMatt Arsenault 
252c06574ffSMatt Arsenault       // An argument is convertible only if all exit blocks are able to replace
253c06574ffSMatt Arsenault       // it.
254c06574ffSMatt Arsenault       for (ReturnInst *RI : Returns) {
255c06574ffSMatt Arsenault         BasicBlock *BB = RI->getParent();
256c06574ffSMatt Arsenault 
257393b9e9dSNikita Popov         MemDepResult Q = MDA->getPointerDependencyFrom(
2584df8efceSNikita Popov             MemoryLocation::getBeforeOrAfter(OutArg), true, BB->end(), BB, RI);
259c06574ffSMatt Arsenault         StoreInst *SI = nullptr;
260c06574ffSMatt Arsenault         if (Q.isDef())
261c06574ffSMatt Arsenault           SI = dyn_cast<StoreInst>(Q.getInst());
262c06574ffSMatt Arsenault 
263c06574ffSMatt Arsenault         if (SI) {
264d34e60caSNicola Zaghen           LLVM_DEBUG(dbgs() << "Found out argument store: " << *SI << '\n');
265c06574ffSMatt Arsenault           ReplaceableStores.emplace_back(RI, SI);
266c06574ffSMatt Arsenault         } else {
267c06574ffSMatt Arsenault           ThisReplaceable = false;
268c06574ffSMatt Arsenault           break;
269c06574ffSMatt Arsenault         }
270c06574ffSMatt Arsenault       }
271c06574ffSMatt Arsenault 
272c06574ffSMatt Arsenault       if (!ThisReplaceable)
273c06574ffSMatt Arsenault         continue; // Try the next argument candidate.
274c06574ffSMatt Arsenault 
275c06574ffSMatt Arsenault       for (std::pair<ReturnInst *, StoreInst *> Store : ReplaceableStores) {
276c06574ffSMatt Arsenault         Value *ReplVal = Store.second->getValueOperand();
277c06574ffSMatt Arsenault 
278c06574ffSMatt Arsenault         auto &ValVec = Replacements[Store.first];
279d6ff5cf9SKazu Hirata         if (llvm::any_of(ValVec,
280c06574ffSMatt Arsenault                          [OutArg](const std::pair<Argument *, Value *> &Entry) {
281d6ff5cf9SKazu Hirata                            return Entry.first == OutArg;
282d6ff5cf9SKazu Hirata                          })) {
283d34e60caSNicola Zaghen           LLVM_DEBUG(dbgs()
284d34e60caSNicola Zaghen                      << "Saw multiple out arg stores" << *OutArg << '\n');
285c06574ffSMatt Arsenault           // It is possible to see stores to the same argument multiple times,
286c06574ffSMatt Arsenault           // but we expect these would have been optimized out already.
287c06574ffSMatt Arsenault           ThisReplaceable = false;
288c06574ffSMatt Arsenault           break;
289c06574ffSMatt Arsenault         }
290c06574ffSMatt Arsenault 
291c06574ffSMatt Arsenault         ValVec.emplace_back(OutArg, ReplVal);
292c06574ffSMatt Arsenault         Store.second->eraseFromParent();
293c06574ffSMatt Arsenault       }
294c06574ffSMatt Arsenault 
295c06574ffSMatt Arsenault       if (ThisReplaceable) {
296c06574ffSMatt Arsenault         ReturnTypes.push_back(ArgTy);
297*99702734SNikita Popov         OutArgIndexes.insert({OutArg->getArgNo(), ArgTy});
298c06574ffSMatt Arsenault         ++NumOutArgumentsReplaced;
299c06574ffSMatt Arsenault         Changing = true;
300c06574ffSMatt Arsenault       }
301c06574ffSMatt Arsenault     }
302c06574ffSMatt Arsenault   } while (Changing);
303c06574ffSMatt Arsenault 
304c06574ffSMatt Arsenault   if (Replacements.empty())
305c06574ffSMatt Arsenault     return false;
306c06574ffSMatt Arsenault 
307c06574ffSMatt Arsenault   LLVMContext &Ctx = F.getParent()->getContext();
308c06574ffSMatt Arsenault   StructType *NewRetTy = StructType::create(Ctx, ReturnTypes, F.getName());
309c06574ffSMatt Arsenault 
310c06574ffSMatt Arsenault   FunctionType *NewFuncTy = FunctionType::get(NewRetTy,
311c06574ffSMatt Arsenault                                               F.getFunctionType()->params(),
312c06574ffSMatt Arsenault                                               F.isVarArg());
313c06574ffSMatt Arsenault 
314d34e60caSNicola Zaghen   LLVM_DEBUG(dbgs() << "Computed new return type: " << *NewRetTy << '\n');
315c06574ffSMatt Arsenault 
316c06574ffSMatt Arsenault   Function *NewFunc = Function::Create(NewFuncTy, Function::PrivateLinkage,
317c06574ffSMatt Arsenault                                        F.getName() + ".body");
318c06574ffSMatt Arsenault   F.getParent()->getFunctionList().insert(F.getIterator(), NewFunc);
319c06574ffSMatt Arsenault   NewFunc->copyAttributesFrom(&F);
320c06574ffSMatt Arsenault   NewFunc->setComdat(F.getComdat());
321c06574ffSMatt Arsenault 
322c06574ffSMatt Arsenault   // We want to preserve the function and param attributes, but need to strip
323c06574ffSMatt Arsenault   // off any return attributes, e.g. zeroext doesn't make sense with a struct.
324c06574ffSMatt Arsenault   NewFunc->stealArgumentListFrom(F);
325c06574ffSMatt Arsenault 
3269290ccc3Sserge-sans-paille   AttributeMask RetAttrs;
327c06574ffSMatt Arsenault   RetAttrs.addAttribute(Attribute::SExt);
328c06574ffSMatt Arsenault   RetAttrs.addAttribute(Attribute::ZExt);
329c06574ffSMatt Arsenault   RetAttrs.addAttribute(Attribute::NoAlias);
330ad727ab7SArthur Eubanks   NewFunc->removeRetAttrs(RetAttrs);
331c06574ffSMatt Arsenault   // TODO: How to preserve metadata?
332c06574ffSMatt Arsenault 
333c06574ffSMatt Arsenault   // Move the body of the function into the new rewritten function, and replace
334c06574ffSMatt Arsenault   // this function with a stub.
335c06574ffSMatt Arsenault   NewFunc->getBasicBlockList().splice(NewFunc->begin(), F.getBasicBlockList());
336c06574ffSMatt Arsenault 
337c06574ffSMatt Arsenault   for (std::pair<ReturnInst *, ReplacementVec> &Replacement : Replacements) {
338c06574ffSMatt Arsenault     ReturnInst *RI = Replacement.first;
339c06574ffSMatt Arsenault     IRBuilder<> B(RI);
340c06574ffSMatt Arsenault     B.SetCurrentDebugLocation(RI->getDebugLoc());
341c06574ffSMatt Arsenault 
342c06574ffSMatt Arsenault     int RetIdx = 0;
343c06574ffSMatt Arsenault     Value *NewRetVal = UndefValue::get(NewRetTy);
344c06574ffSMatt Arsenault 
345c06574ffSMatt Arsenault     Value *RetVal = RI->getReturnValue();
346c06574ffSMatt Arsenault     if (RetVal)
347c06574ffSMatt Arsenault       NewRetVal = B.CreateInsertValue(NewRetVal, RetVal, RetIdx++);
348c06574ffSMatt Arsenault 
349*99702734SNikita Popov     for (std::pair<Argument *, Value *> ReturnPoint : Replacement.second)
350*99702734SNikita Popov       NewRetVal = B.CreateInsertValue(NewRetVal, ReturnPoint.second, RetIdx++);
351c06574ffSMatt Arsenault 
352c06574ffSMatt Arsenault     if (RetVal)
353c06574ffSMatt Arsenault       RI->setOperand(0, NewRetVal);
354c06574ffSMatt Arsenault     else {
355c06574ffSMatt Arsenault       B.CreateRet(NewRetVal);
356c06574ffSMatt Arsenault       RI->eraseFromParent();
357c06574ffSMatt Arsenault     }
358c06574ffSMatt Arsenault   }
359c06574ffSMatt Arsenault 
360c06574ffSMatt Arsenault   SmallVector<Value *, 16> StubCallArgs;
361c06574ffSMatt Arsenault   for (Argument &Arg : F.args()) {
362c06574ffSMatt Arsenault     if (OutArgIndexes.count(Arg.getArgNo())) {
363c06574ffSMatt Arsenault       // It's easier to preserve the type of the argument list. We rely on
364c06574ffSMatt Arsenault       // DeadArgumentElimination to take care of these.
365c06574ffSMatt Arsenault       StubCallArgs.push_back(UndefValue::get(Arg.getType()));
366c06574ffSMatt Arsenault     } else {
367c06574ffSMatt Arsenault       StubCallArgs.push_back(&Arg);
368c06574ffSMatt Arsenault     }
369c06574ffSMatt Arsenault   }
370c06574ffSMatt Arsenault 
371c06574ffSMatt Arsenault   BasicBlock *StubBB = BasicBlock::Create(Ctx, "", &F);
372c06574ffSMatt Arsenault   IRBuilder<> B(StubBB);
373c06574ffSMatt Arsenault   CallInst *StubCall = B.CreateCall(NewFunc, StubCallArgs);
374c06574ffSMatt Arsenault 
375c06574ffSMatt Arsenault   int RetIdx = RetTy->isVoidTy() ? 0 : 1;
376c06574ffSMatt Arsenault   for (Argument &Arg : F.args()) {
377c06574ffSMatt Arsenault     if (!OutArgIndexes.count(Arg.getArgNo()))
378c06574ffSMatt Arsenault       continue;
379c06574ffSMatt Arsenault 
380da9ab148SMatt Arsenault     PointerType *ArgType = cast<PointerType>(Arg.getType());
381da9ab148SMatt Arsenault 
382*99702734SNikita Popov     Type *EltTy = OutArgIndexes[Arg.getArgNo()];
38359f95222SGuillaume Chatelet     const auto Align =
38459f95222SGuillaume Chatelet         DL->getValueOrABITypeAlignment(Arg.getParamAlign(), EltTy);
385c06574ffSMatt Arsenault 
386c06574ffSMatt Arsenault     Value *Val = B.CreateExtractValue(StubCall, RetIdx++);
387da9ab148SMatt Arsenault     Type *PtrTy = Val->getType()->getPointerTo(ArgType->getAddressSpace());
388da9ab148SMatt Arsenault 
389da9ab148SMatt Arsenault     // We can peek through bitcasts, so the type may not match.
390da9ab148SMatt Arsenault     Value *PtrVal = B.CreateBitCast(&Arg, PtrTy);
391da9ab148SMatt Arsenault 
392da9ab148SMatt Arsenault     B.CreateAlignedStore(Val, PtrVal, Align);
393c06574ffSMatt Arsenault   }
394c06574ffSMatt Arsenault 
395c06574ffSMatt Arsenault   if (!RetTy->isVoidTy()) {
396c06574ffSMatt Arsenault     B.CreateRet(B.CreateExtractValue(StubCall, 0));
397c06574ffSMatt Arsenault   } else {
398c06574ffSMatt Arsenault     B.CreateRetVoid();
399c06574ffSMatt Arsenault   }
400c06574ffSMatt Arsenault 
401c06574ffSMatt Arsenault   // The function is now a stub we want to inline.
402c06574ffSMatt Arsenault   F.addFnAttr(Attribute::AlwaysInline);
403c06574ffSMatt Arsenault 
404c06574ffSMatt Arsenault   ++NumOutArgumentFunctionsReplaced;
405c06574ffSMatt Arsenault   return true;
406c06574ffSMatt Arsenault }
407c06574ffSMatt Arsenault 
createAMDGPURewriteOutArgumentsPass()408c06574ffSMatt Arsenault FunctionPass *llvm::createAMDGPURewriteOutArgumentsPass() {
409c06574ffSMatt Arsenault   return new AMDGPURewriteOutArguments();
410c06574ffSMatt Arsenault }
411