//===-- AMDGPULowerKernelArguments.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass replaces accesses to kernel arguments with loads from
/// offsets from the kernarg base pointer.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUSubtarget.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-arguments"

using namespace llvm;

namespace {

class AMDGPULowerKernelArguments : public FunctionPass {
public:
  static char ID;

  AMDGPULowerKernelArguments() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace

// skip allocas
static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
  BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
  for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
    AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);

    // If this is a dynamic alloca, the value may depend on the loaded
    // kernargs, so loads will need to be inserted before it.
    if (!AI || !AI->isStaticAlloca())
      break;
  }

  return InsPt;
}
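
// Illustrative sketch of the rewrite runOnFunction performs below. The value
// names, offsets, and address-space numbers are only an example of the
// general pattern; the real offsets depend on the subtarget's explicit
// kernarg layout:
//
//   define amdgpu_kernel void @f(i32 %x) { ... uses of %x ... }
//
// becomes, roughly:
//
//   %f.kernarg.segment = call i8 addrspace(4)*
//       @llvm.amdgcn.kernarg.segment.ptr()
//   %x.kernarg.offset = getelementptr inbounds i8, i8 addrspace(4)*
//       %f.kernarg.segment, i64 <offset of %x>
//   %x.kernarg.offset.cast = bitcast i8 addrspace(4)* %x.kernarg.offset
//       to i32 addrspace(4)*
//   %x.load = load i32, i32 addrspace(4)* %x.kernarg.offset.cast, align 4,
//       !invariant.load !0
//
// with every use of %x replaced by %x.load.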
bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
  CallingConv::ID CC = F.getCallingConv();
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();

  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
  LLVMContext &Ctx = F.getParent()->getContext();
  const DataLayout &DL = F.getParent()->getDataLayout();
  BasicBlock &EntryBlock = *F.begin();
  IRBuilder<> Builder(&*getInsertPt(EntryBlock));

  const Align KernArgBaseAlign(16); // FIXME: Increase if necessary
  const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);

  Align MaxAlign;
  // FIXME: Alignment is broken with explicit arg offset.
  const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
  if (TotalKernArgSize == 0)
    return false;

  CallInst *KernArgSegment =
      Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
                              nullptr, F.getName() + ".kernarg.segment");

  KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
  KernArgSegment->addAttribute(AttributeList::ReturnIndex,
    Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));

  unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
  uint64_t ExplicitArgOffset = 0;

  for (Argument &Arg : F.args()) {
    const bool IsByRef = Arg.hasByRefAttr();
    Type *ArgTy = IsByRef ? Arg.getParamByRefType() : Arg.getType();
    MaybeAlign ABITypeAlign = IsByRef ? Arg.getParamAlign() : None;
    if (!ABITypeAlign)
      ABITypeAlign = DL.getABITypeAlign(ArgTy);

    uint64_t Size = DL.getTypeSizeInBits(ArgTy);
    uint64_t AllocSize = DL.getTypeAllocSize(ArgTy);

    uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
    ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;

    if (Arg.use_empty())
      continue;

    // If this is byref, the loads are already explicit in the function. We
    // just need to rewrite the pointer values.
    if (IsByRef) {
      Value *ArgOffsetPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, EltOffset,
          Arg.getName() + ".byval.kernarg.offset");

      Value *CastOffsetPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
          ArgOffsetPtr, Arg.getType());
      Arg.replaceAllUsesWith(CastOffsetPtr);
      continue;
    }

    if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
      // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
      // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
      // can't represent this with range metadata because it's only allowed for
      // integer types.
      if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
           PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
          !ST.hasUsableDSOffset())
        continue;

      // FIXME: We can replace this with equivalent alias.scope/noalias
      // metadata, but this appears to be a lot of work.
      if (Arg.hasNoAliasAttr())
        continue;
    }
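
    // Choose how the argument will actually be loaded: sub-dword scalars are
    // loaded as a dword-aligned i32 and extracted with a shift/trunc
    // (DoShiftOpt below), and 3-element vectors are loaded as 4-element
    // vectors and shuffled back down to 3 elements (IsV3).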
    auto *VT = dyn_cast<FixedVectorType>(ArgTy);
    bool IsV3 = VT && VT->getNumElements() == 3;
    bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();

    VectorType *V4Ty = nullptr;

    int64_t AlignDownOffset = alignDown(EltOffset, 4);
    int64_t OffsetDiff = EltOffset - AlignDownOffset;
    Align AdjustedAlign = commonAlignment(
        KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset);

    Value *ArgPtr;
    Type *AdjustedArgTy;
    if (DoShiftOpt) { // FIXME: Handle aggregate types
      // Since we don't have sub-dword scalar loads, avoid doing an extload by
      // loading earlier than the argument address, and extracting the relevant
      // bits.
      //
      // Additionally widen any sub-dword load to i32 even if suitably aligned,
      // so that CSE between different argument loads works easily.
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
          Arg.getName() + ".kernarg.offset.align.down");
      AdjustedArgTy = Builder.getInt32Ty();
    } else {
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, EltOffset,
          Arg.getName() + ".kernarg.offset");
      AdjustedArgTy = ArgTy;
    }

    if (IsV3 && Size >= 32) {
      V4Ty = FixedVectorType::get(VT->getElementType(), 4);
      // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
      AdjustedArgTy = V4Ty;
    }

    ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS),
                                   ArgPtr->getName() + ".cast");
    LoadInst *Load =
        Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign);
    Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));

    MDBuilder MDB(Ctx);

    if (isa<PointerType>(ArgTy)) {
      if (Arg.hasNonNullAttr())
        Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));

      uint64_t DerefBytes = Arg.getDereferenceableBytes();
      if (DerefBytes != 0) {
        Load->setMetadata(
            LLVMContext::MD_dereferenceable,
            MDNode::get(Ctx,
                        MDB.createConstant(
                            ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
      }

      uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
      if (DerefOrNullBytes != 0) {
        Load->setMetadata(
            LLVMContext::MD_dereferenceable_or_null,
            MDNode::get(Ctx,
                        MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                            DerefOrNullBytes))));
      }

      unsigned ParamAlign = Arg.getParamAlignment();
      if (ParamAlign != 0) {
        Load->setMetadata(
            LLVMContext::MD_align,
            MDNode::get(Ctx,
                        MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                            ParamAlign))));
      }
    }

    // TODO: Convert noalias arg to !noalias

    if (DoShiftOpt) {
      Value *ExtractBits = OffsetDiff == 0 ?
        Load : Builder.CreateLShr(Load, OffsetDiff * 8);

      IntegerType *ArgIntTy = Builder.getIntNTy(Size);
      Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
      Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
                                            Arg.getName() + ".load");
      Arg.replaceAllUsesWith(NewVal);
    } else if (IsV3) {
      Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
                                                ArrayRef<int>{0, 1, 2},
                                                Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Shuf);
    } else {
      Load->setName(Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Load);
    }
  }

  KernArgSegment->addAttribute(
      AttributeList::ReturnIndex,
      Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
                      "AMDGPU Lower Kernel Arguments", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE,
                    "AMDGPU Lower Kernel Arguments", false, false)

char AMDGPULowerKernelArguments::ID = 0;

FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
  return new AMDGPULowerKernelArguments();
}
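
// A minimal sketch of where this pass is typically scheduled. This is an
// assumption about the GCN pass pipeline (set up in AMDGPUTargetMachine.cpp,
// not in this file), shown here only for orientation:
//
//   // illustrative; the flag name and exact hook may differ
//   if (EnableLowerKernelArguments)
//     addPass(createAMDGPULowerKernelArgumentsPass());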