//===- AMDGPUUnifyDivergentExitNodes.cpp ----------------------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This is a variant of the UnifyDivergentExitNodes pass. Rather than ensuring
// there is at most one ret and one unreachable instruction, it ensures there is
// at most one divergent exiting block.
//
// StructurizeCFG can't deal with multi-exit regions formed by branches to
// multiple return nodes. It is not desirable to structurize regions with
// uniform branches, so unifying those to the same return block as divergent
// branches inhibits use of scalar branching. It still can't deal with the case
// where one branch goes to return, and one unreachable. Replace unreachable in
// this case with a return.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/DivergenceAnalysis.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Type.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

#define DEBUG_TYPE "amdgpu-unify-divergent-exit-nodes"

namespace {

/// Legacy function pass that rewrites a function so it has at most one exit
/// block that is reached through divergent control flow (see the file header
/// comment for the motivation relative to StructurizeCFG).
class AMDGPUUnifyDivergentExitNodes : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  AMDGPUUnifyDivergentExitNodes() : FunctionPass(ID) {
    initializeAMDGPUUnifyDivergentExitNodesPass(*PassRegistry::getPassRegistry());
  }

  // We can preserve non-critical-edgeness when we unify function exit nodes
  void getAnalysisUsage(AnalysisUsage &AU) const override;
  bool runOnFunction(Function &F) override;
};

} // end anonymous namespace

char AMDGPUUnifyDivergentExitNodes::ID = 0;

INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
                      "Unify divergent function exit nodes", false, false)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DivergenceAnalysis)
INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
                    "Unify divergent function exit nodes", false, false)

char &llvm::AMDGPUUnifyDivergentExitNodesID = AMDGPUUnifyDivergentExitNodes::ID;

void AMDGPUUnifyDivergentExitNodes::getAnalysisUsage(AnalysisUsage &AU) const{
  // TODO: Preserve dominator tree.
  AU.addRequired<PostDominatorTreeWrapperPass>();

  AU.addRequired<DivergenceAnalysis>();

  // No divergent values are changed, only blocks and branch edges.
  AU.addPreserved<DivergenceAnalysis>();

  // We preserve the non-critical-edgeness property
  AU.addPreservedID(BreakCriticalEdgesID);

  // This is a cluster of orthogonal Transforms
  AU.addPreservedID(LowerSwitchID);
  FunctionPass::getAnalysisUsage(AU);

  // TTI is consumed by the SimplifyCFG cleanup performed after the returning
  // blocks have been rewritten to branch to the unified return block.
  AU.addRequired<TargetTransformInfoWrapperPass>();
}

/// \returns true if \p BB is reachable through only uniform branches.
/// XXX - Is there a more efficient way to find this?
87 static bool isUniformlyReached(const DivergenceAnalysis &DA, 88 BasicBlock &BB) { 89 SmallVector<BasicBlock *, 8> Stack; 90 SmallPtrSet<BasicBlock *, 8> Visited; 91 92 for (BasicBlock *Pred : predecessors(&BB)) 93 Stack.push_back(Pred); 94 95 while (!Stack.empty()) { 96 BasicBlock *Top = Stack.pop_back_val(); 97 if (!DA.isUniform(Top->getTerminator())) 98 return false; 99 100 for (BasicBlock *Pred : predecessors(Top)) { 101 if (Visited.insert(Pred).second) 102 Stack.push_back(Pred); 103 } 104 } 105 106 return true; 107 } 108 109 static BasicBlock *unifyReturnBlockSet(Function &F, 110 ArrayRef<BasicBlock *> ReturningBlocks, 111 const TargetTransformInfo &TTI, 112 StringRef Name) { 113 // Otherwise, we need to insert a new basic block into the function, add a PHI 114 // nodes (if the function returns values), and convert all of the return 115 // instructions into unconditional branches. 116 // 117 BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(), Name, &F); 118 119 PHINode *PN = nullptr; 120 if (F.getReturnType()->isVoidTy()) { 121 ReturnInst::Create(F.getContext(), nullptr, NewRetBlock); 122 } else { 123 // If the function doesn't return void... add a PHI node to the block... 124 PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(), 125 "UnifiedRetVal"); 126 NewRetBlock->getInstList().push_back(PN); 127 ReturnInst::Create(F.getContext(), PN, NewRetBlock); 128 } 129 130 // Loop over all of the blocks, replacing the return instruction with an 131 // unconditional branch. 132 // 133 for (BasicBlock *BB : ReturningBlocks) { 134 // Add an incoming element to the PHI node for every return instruction that 135 // is merging into this new block... 136 if (PN) 137 PN->addIncoming(BB->getTerminator()->getOperand(0), BB); 138 139 BB->getInstList().pop_back(); // Remove the return insn 140 BranchInst::Create(NewRetBlock, BB); 141 } 142 143 for (BasicBlock *BB : ReturningBlocks) { 144 // Cleanup possible branch to unconditional branch to the return. 
145 SimplifyCFG(BB, TTI, 2); 146 } 147 148 return NewRetBlock; 149 } 150 151 bool AMDGPUUnifyDivergentExitNodes::runOnFunction(Function &F) { 152 auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); 153 if (PDT.getRoots().size() <= 1) 154 return false; 155 156 DivergenceAnalysis &DA = getAnalysis<DivergenceAnalysis>(); 157 158 // Loop over all of the blocks in a function, tracking all of the blocks that 159 // return. 160 // 161 SmallVector<BasicBlock *, 4> ReturningBlocks; 162 SmallVector<BasicBlock *, 4> UnreachableBlocks; 163 164 for (BasicBlock *BB : PDT.getRoots()) { 165 if (isa<ReturnInst>(BB->getTerminator())) { 166 if (!isUniformlyReached(DA, *BB)) 167 ReturningBlocks.push_back(BB); 168 } else if (isa<UnreachableInst>(BB->getTerminator())) { 169 if (!isUniformlyReached(DA, *BB)) 170 UnreachableBlocks.push_back(BB); 171 } 172 } 173 174 if (!UnreachableBlocks.empty()) { 175 BasicBlock *UnreachableBlock = nullptr; 176 177 if (UnreachableBlocks.size() == 1) { 178 UnreachableBlock = UnreachableBlocks.front(); 179 } else { 180 UnreachableBlock = BasicBlock::Create(F.getContext(), 181 "UnifiedUnreachableBlock", &F); 182 new UnreachableInst(F.getContext(), UnreachableBlock); 183 184 for (BasicBlock *BB : UnreachableBlocks) { 185 BB->getInstList().pop_back(); // Remove the unreachable inst. 186 BranchInst::Create(UnreachableBlock, BB); 187 } 188 } 189 190 if (!ReturningBlocks.empty()) { 191 // Don't create a new unreachable inst if we have a return. The 192 // structurizer/annotator can't handle the multiple exits 193 194 Type *RetTy = F.getReturnType(); 195 Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy); 196 UnreachableBlock->getInstList().pop_back(); // Remove the unreachable inst. 
197 198 Function *UnreachableIntrin = 199 Intrinsic::getDeclaration(F.getParent(), Intrinsic::amdgcn_unreachable); 200 201 // Insert a call to an intrinsic tracking that this is an unreachable 202 // point, in case we want to kill the active lanes or something later. 203 CallInst::Create(UnreachableIntrin, {}, "", UnreachableBlock); 204 205 // Don't create a scalar trap. We would only want to trap if this code was 206 // really reached, but a scalar trap would happen even if no lanes 207 // actually reached here. 208 ReturnInst::Create(F.getContext(), RetVal, UnreachableBlock); 209 ReturningBlocks.push_back(UnreachableBlock); 210 } 211 } 212 213 // Now handle return blocks. 214 if (ReturningBlocks.empty()) 215 return false; // No blocks return 216 217 if (ReturningBlocks.size() == 1) 218 return false; // Already has a single return block 219 220 const TargetTransformInfo &TTI 221 = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 222 223 unifyReturnBlockSet(F, ReturningBlocks, TTI, "UnifiedReturnBlock"); 224 return true; 225 } 226