//===-- AMDGPURewriteOutArgumentsPass.cpp - Create struct returns ---------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
/// \file This pass attempts to replace out argument usage with a return of a
/// struct.
///
/// We can support returning a lot of values directly in registers, but
/// idiomatic C code frequently uses a pointer argument to return a second value
/// rather than returning a struct by value. GPU stack access is also quite
/// painful, so we want to avoid that if possible. Passing a stack object
/// pointer to a function also requires an additional address expansion code
/// sequence to convert the pointer to be relative to the kernel's scratch wave
/// offset register, since the callee doesn't know which stack frame the
/// incoming pointer is relative to.
///
/// The goal is to try rewriting code that looks like this:
///
///  int foo(int a, int b, int *out) {
///    *out = bar();
///    return a + b;
///  }
///
/// into something like this:
///
///  std::pair<int, int> foo(int a, int b) {
///    return std::make_pair(a + b, bar());
///  }
///
/// Typically the incoming pointer is a simple alloca for a temporary variable
/// used to call the API, and if it is replaced with a struct return it will be
/// easily SROA'd out when the stub function we create is inlined.
///
/// This pass introduces the struct return, but leaves the unused pointer
/// arguments and introduces a new stub function calling the struct returning
/// body. DeadArgumentElimination should be run after this to clean these up.
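///
/// For illustration, the rewritten foo from the example above would look
/// roughly like this at the IR level (a sketch, not exact compiler output):
///
///  ; Original body, now returning the out value as a struct member.
///  define private { i32, i32 } @foo.body(i32 %a, i32 %b, i32* %out)
///
///  ; Stub replacing foo: call the body, store the extracted member back
///  ; through the original out pointer, and return the original value.
///  define i32 @foo(i32 %a, i32 %b, i32* %out) #0 {  ; #0 = alwaysinline
///    %r = call { i32, i32 } @foo.body(i32 %a, i32 %b, i32* undef)
///    %v = extractvalue { i32, i32 } %r, 1
///    store i32 %v, i32* %out
///    %ret = extractvalue { i32, i32 } %r, 0
///    ret i32 %ret
///  }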
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "Utils/AMDGPUBaseInfo.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/MemoryDependenceAnalysis.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "amdgpu-rewrite-out-arguments"

using namespace llvm;

namespace {

static cl::opt<bool> AnyAddressSpace(
  "amdgpu-any-address-space-out-arguments",
  cl::desc("Replace pointer out arguments with "
           "struct returns for non-private address space"),
  cl::Hidden,
  cl::init(false));

static cl::opt<unsigned> MaxNumRetRegs(
  "amdgpu-max-return-arg-num-regs",
  cl::desc("Approximately limit number of return registers for replacing out "
           "arguments"),
  cl::Hidden,
  cl::init(16));

STATISTIC(NumOutArgumentsReplaced,
          "Number of out arguments moved to struct return values");
STATISTIC(NumOutArgumentFunctionsReplaced,
          "Number of functions with out arguments moved to struct return "
          "values");

class AMDGPURewriteOutArguments : public FunctionPass {
private:
  const DataLayout *DL = nullptr;
  MemoryDependenceResults *MDA = nullptr;

  bool checkArgumentUses(Value &Arg) const;
  bool isOutArgumentCandidate(Argument &Arg) const;
  bool isVec3ToVec4Shuffle(Type *Ty0, Type *Ty1) const;

public:
  static char ID;

  AMDGPURewriteOutArguments() : FunctionPass(ID) {}

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<MemoryDependenceWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool doInitialization(Module &M) override;
  bool runOnFunction(Function &F) override;
};

} // end anonymous namespace

INITIALIZE_PASS_BEGIN(AMDGPURewriteOutArguments, DEBUG_TYPE,
                      "AMDGPU Rewrite Out Arguments", false, false)
INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
INITIALIZE_PASS_END(AMDGPURewriteOutArguments, DEBUG_TYPE,
                    "AMDGPU Rewrite Out Arguments", false, false)

char AMDGPURewriteOutArguments::ID = 0;

bool AMDGPURewriteOutArguments::checkArgumentUses(Value &Arg) const {
  const int MaxUses = 10;
  int UseCount = 0;

  for (Use &U : Arg.uses()) {
    StoreInst *SI = dyn_cast<StoreInst>(U.getUser());
    if (UseCount > MaxUses)
      return false;

    if (!SI) {
      auto *BCI = dyn_cast<BitCastInst>(U.getUser());
      if (!BCI || !BCI->hasOneUse())
        return false;

      // We don't handle multiple stores currently, so stores to aggregate
      // pointers aren't worth the trouble since they are canonically split up.
      Type *DestEltTy = BCI->getType()->getPointerElementType();
      if (DestEltTy->isAggregateType())
        return false;

      // We could handle these if we had a convenient way to bitcast between
      // them.
      Type *SrcEltTy = Arg.getType()->getPointerElementType();
      if (SrcEltTy->isArrayTy())
        return false;

      // Special case handle structs with single members. It is useful to
      // handle some casts between structs and non-structs, but we can't
      // directly bitcast between them. Blender uses some casts that look like
      // { <3 x float> }* to <4 x float>*.
      if (SrcEltTy->isStructTy() && SrcEltTy->getNumContainedTypes() != 1)
        return false;

      // Clang emits OpenCL 3-vector type accesses with a bitcast to the
      // equivalent 4-element vector and accesses that, and we're looking for
      // this pointer cast.
      if (DL->getTypeAllocSize(SrcEltTy) != DL->getTypeAllocSize(DestEltTy))
        return false;

      return checkArgumentUses(*BCI);
    }

    if (!SI->isSimple() ||
        U.getOperandNo() != StoreInst::getPointerOperandIndex())
      return false;

    ++UseCount;
  }

  // Skip unused arguments.
  return UseCount > 0;
}
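
// As an illustration of the pattern checkArgumentUses accepts (a sketch of
// typical clang output for an OpenCL 3-vector store, not exact IR):
//
//   %cast = bitcast <3 x float>* %out to <4 x float>*
//   store <4 x float> %val, <4 x float>* %cast
//
// The bitcast has a single use, the destination element type is not an
// aggregate, and both element types have the same allocation size (16 bytes),
// so the recursive check on the bitcast sees only the simple store and
// accepts the argument.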

bool AMDGPURewriteOutArguments::isOutArgumentCandidate(Argument &Arg) const {
  const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs;
  PointerType *ArgTy = dyn_cast<PointerType>(Arg.getType());

  // TODO: It might be useful for any out arguments, not just privates.
  if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() &&
                 !AnyAddressSpace) ||
      Arg.hasByValAttr() || Arg.hasStructRetAttr() ||
      DL->getTypeStoreSize(ArgTy->getPointerElementType()) >
          MaxOutArgSizeBytes) {
    return false;
  }

  return checkArgumentUses(Arg);
}

bool AMDGPURewriteOutArguments::doInitialization(Module &M) {
  DL = &M.getDataLayout();
  return false;
}

bool AMDGPURewriteOutArguments::isVec3ToVec4Shuffle(Type *Ty0,
                                                    Type *Ty1) const {
  VectorType *VT0 = dyn_cast<VectorType>(Ty0);
  VectorType *VT1 = dyn_cast<VectorType>(Ty1);
  if (!VT0 || !VT1)
    return false;

  if (VT0->getNumElements() != 3 || VT1->getNumElements() != 4)
    return false;

  return DL->getTypeSizeInBits(VT0->getElementType()) ==
         DL->getTypeSizeInBits(VT1->getElementType());
}

bool AMDGPURewriteOutArguments::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  // TODO: Could probably handle variadic functions.
  if (F.isVarArg() || F.hasStructRetAttr() ||
      AMDGPU::isEntryFunctionCC(F.getCallingConv()))
    return false;

  MDA = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();

  unsigned ReturnNumRegs = 0;
  SmallSet<int, 4> OutArgIndexes;
  SmallVector<Type *, 4> ReturnTypes;
  Type *RetTy = F.getReturnType();
  if (!RetTy->isVoidTy()) {
    ReturnNumRegs = DL->getTypeStoreSize(RetTy) / 4;

    if (ReturnNumRegs >= MaxNumRetRegs)
      return false;

    ReturnTypes.push_back(RetTy);
  }

  SmallVector<Argument *, 4> OutArgs;
  for (Argument &Arg : F.args()) {
    if (isOutArgumentCandidate(Arg)) {
      DEBUG(dbgs() << "Found possible out argument " << Arg
                   << " in function " << F.getName() << '\n');
      OutArgs.push_back(&Arg);
    }
  }

  if (OutArgs.empty())
    return false;

  using ReplacementVec = SmallVector<std::pair<Argument *, Value *>, 4>;
  DenseMap<ReturnInst *, ReplacementVec> Replacements;

  SmallVector<ReturnInst *, 4> Returns;
  for (BasicBlock &BB : F) {
    if (ReturnInst *RI = dyn_cast<ReturnInst>(&BB.back()))
      Returns.push_back(RI);
  }

  if (Returns.empty())
    return false;

  bool Changing;

  do {
    Changing = false;

    // Keep retrying if we are able to successfully eliminate an argument. This
    // helps with cases with multiple arguments which may alias, such as in a
    // sincos implementation. If we have two stores to out arguments, on the
    // first attempt the MDA query will succeed for the second store but not
    // the first. On the second iteration we've removed that clobbering out
    // argument (by effectively moving it into another function) and will find
    // the remaining argument is OK to move.
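    //
    // Illustrative sketch (not from the source): in
    //
    //   void sincos(float x, float *sin_out, float *cos_out) {
    //     *sin_out = ...;
    //     *cos_out = ...;
    //   }
    //
    // the backwards query for sin_out's store first hits the possibly
    // aliasing cos_out store and fails, while cos_out's own store is found
    // directly. Once cos_out has been rewritten and its store erased, sin_out
    // becomes replaceable on the next iteration.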
    for (Argument *OutArg : OutArgs) {
      bool ThisReplaceable = true;
      SmallVector<std::pair<ReturnInst *, StoreInst *>, 4> ReplaceableStores;

      Type *ArgTy = OutArg->getType()->getPointerElementType();

      // Skip this argument if converting it will push us over the register
      // count limit for the return value.

      // TODO: This is an approximation. When legalized this could be more. We
      // can ask TLI for exactly how many.
      unsigned ArgNumRegs = DL->getTypeStoreSize(ArgTy) / 4;
      if (ArgNumRegs + ReturnNumRegs > MaxNumRetRegs)
        continue;

      // An argument is convertible only if all exit blocks are able to replace
      // it.
      for (ReturnInst *RI : Returns) {
        BasicBlock *BB = RI->getParent();

        MemDepResult Q = MDA->getPointerDependencyFrom(MemoryLocation(OutArg),
                                                       true, BB->end(), BB, RI);
        StoreInst *SI = nullptr;
        if (Q.isDef())
          SI = dyn_cast<StoreInst>(Q.getInst());

        if (SI) {
          DEBUG(dbgs() << "Found out argument store: " << *SI << '\n');
          ReplaceableStores.emplace_back(RI, SI);
        } else {
          ThisReplaceable = false;
          break;
        }
      }

      if (!ThisReplaceable)
        continue; // Try the next argument candidate.

      for (std::pair<ReturnInst *, StoreInst *> Store : ReplaceableStores) {
        Value *ReplVal = Store.second->getValueOperand();

        auto &ValVec = Replacements[Store.first];
        if (llvm::find_if(ValVec,
                          [OutArg](const std::pair<Argument *, Value *> &Entry) {
                            return Entry.first == OutArg;
                          }) != ValVec.end()) {
          DEBUG(dbgs() << "Saw multiple out arg stores " << *OutArg << '\n');
          // It is possible to see stores to the same argument multiple times,
          // but we expect these would have been optimized out already.
          ThisReplaceable = false;
          break;
        }

        ValVec.emplace_back(OutArg, ReplVal);
        Store.second->eraseFromParent();
      }

      if (ThisReplaceable) {
        ReturnTypes.push_back(ArgTy);
        OutArgIndexes.insert(OutArg->getArgNo());
        ++NumOutArgumentsReplaced;
        Changing = true;
      }
    }
  } while (Changing);

  if (Replacements.empty())
    return false;

  LLVMContext &Ctx = F.getParent()->getContext();
  StructType *NewRetTy = StructType::create(Ctx, ReturnTypes, F.getName());

  FunctionType *NewFuncTy = FunctionType::get(NewRetTy,
                                              F.getFunctionType()->params(),
                                              F.isVarArg());

  DEBUG(dbgs() << "Computed new return type: " << *NewRetTy << '\n');
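
  // For the foo example from the file comment, ReturnTypes holds the original
  // i32 return type followed by the out argument's pointee type, so NewRetTy
  // is roughly { i32, i32 } and NewFuncTy is { i32, i32 } (i32, i32, i32*)
  // (illustrative; the parameter list is unchanged).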

  Function *NewFunc = Function::Create(NewFuncTy, Function::PrivateLinkage,
                                       F.getName() + ".body");
  F.getParent()->getFunctionList().insert(F.getIterator(), NewFunc);
  NewFunc->copyAttributesFrom(&F);
  NewFunc->setComdat(F.getComdat());
  NewFunc->stealArgumentListFrom(F);

  // We want to preserve the function and param attributes, but need to strip
  // off any return attributes, e.g. zeroext doesn't make sense with a struct.
  AttrBuilder RetAttrs;
  RetAttrs.addAttribute(Attribute::SExt);
  RetAttrs.addAttribute(Attribute::ZExt);
  RetAttrs.addAttribute(Attribute::NoAlias);
  NewFunc->removeAttributes(AttributeList::ReturnIndex, RetAttrs);
  // TODO: How to preserve metadata?

  // Move the body of the function into the new rewritten function, and replace
  // this function with a stub.
  NewFunc->getBasicBlockList().splice(NewFunc->begin(), F.getBasicBlockList());

  for (std::pair<ReturnInst *, ReplacementVec> &Replacement : Replacements) {
    ReturnInst *RI = Replacement.first;
    IRBuilder<> B(RI);
    B.SetCurrentDebugLocation(RI->getDebugLoc());

    int RetIdx = 0;
    Value *NewRetVal = UndefValue::get(NewRetTy);

    Value *RetVal = RI->getReturnValue();
    if (RetVal)
      NewRetVal = B.CreateInsertValue(NewRetVal, RetVal, RetIdx++);

    for (std::pair<Argument *, Value *> ReturnPoint : Replacement.second) {
      Argument *Arg = ReturnPoint.first;
      Value *Val = ReturnPoint.second;

      Type *EltTy = Arg->getType()->getPointerElementType();
      if (Val->getType() != EltTy) {
        Type *EffectiveEltTy = EltTy;
        if (StructType *CT = dyn_cast<StructType>(EltTy)) {
          assert(CT->getNumContainedTypes() == 1);
          EffectiveEltTy = CT->getContainedType(0);
        }

        if (DL->getTypeSizeInBits(EffectiveEltTy) !=
            DL->getTypeSizeInBits(Val->getType())) {
          assert(isVec3ToVec4Shuffle(EffectiveEltTy, Val->getType()));
          Val = B.CreateShuffleVector(Val, UndefValue::get(Val->getType()),
                                      {0, 1, 2});
        }

        Val = B.CreateBitCast(Val, EffectiveEltTy);

        // Re-create single element composite.
        if (EltTy != EffectiveEltTy)
          Val = B.CreateInsertValue(UndefValue::get(EltTy), Val, 0);
      }

      NewRetVal = B.CreateInsertValue(NewRetVal, Val, RetIdx++);
    }

    if (RetVal)
      RI->setOperand(0, NewRetVal);
    else {
      B.CreateRet(NewRetVal);
      RI->eraseFromParent();
    }
  }

  SmallVector<Value *, 16> StubCallArgs;
  for (Argument &Arg : F.args()) {
    if (OutArgIndexes.count(Arg.getArgNo())) {
      // It's easier to preserve the type of the argument list. We rely on
      // DeadArgumentElimination to take care of these.
      StubCallArgs.push_back(UndefValue::get(Arg.getType()));
    } else {
      StubCallArgs.push_back(&Arg);
    }
  }

  BasicBlock *StubBB = BasicBlock::Create(Ctx, "", &F);
  IRBuilder<> B(StubBB);
  CallInst *StubCall = B.CreateCall(NewFunc, StubCallArgs);

  int RetIdx = RetTy->isVoidTy() ? 0 : 1;
  for (Argument &Arg : F.args()) {
    if (!OutArgIndexes.count(Arg.getArgNo()))
      continue;

    PointerType *ArgType = cast<PointerType>(Arg.getType());

    auto *EltTy = ArgType->getElementType();
    unsigned Align = Arg.getParamAlignment();
    if (Align == 0)
      Align = DL->getABITypeAlignment(EltTy);

    Value *Val = B.CreateExtractValue(StubCall, RetIdx++);
    Type *PtrTy = Val->getType()->getPointerTo(ArgType->getAddressSpace());

    // We can peek through bitcasts, so the type may not match.
    Value *PtrVal = B.CreateBitCast(&Arg, PtrTy);

    B.CreateAlignedStore(Val, PtrVal, Align);
  }

  if (!RetTy->isVoidTy()) {
    B.CreateRet(B.CreateExtractValue(StubCall, 0));
  } else {
    B.CreateRetVoid();
  }

  // The function is now a stub we want to inline.
  F.addFnAttr(Attribute::AlwaysInline);

  ++NumOutArgumentFunctionsReplaced;
  return true;
}

FunctionPass *llvm::createAMDGPURewriteOutArgumentsPass() {
  return new AMDGPURewriteOutArguments();
}
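
// Illustrative usage (a sketch, not taken from this file): the factory above
// lets the target schedule the pass through the legacy pass manager, e.g.
//
//   legacy::PassManager PM;
//   PM.add(createAMDGPURewriteOutArgumentsPass());
//
// In-tree this is done from the AMDGPU target's IR pass pipeline, and
// DeadArgumentElimination should run afterwards to clean up the leftover
// pointer arguments.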