10b57cec5SDimitry Andric //===- UnifyFunctionExitNodes.cpp - Make all functions have a single exit -===//
20b57cec5SDimitry Andric //
30b57cec5SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40b57cec5SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
50b57cec5SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60b57cec5SDimitry Andric //
70b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
80b57cec5SDimitry Andric //
9*af732203SDimitry Andric // This pass is used to ensure that functions have at most one return and one
10*af732203SDimitry Andric // unreachable instruction in them.
110b57cec5SDimitry Andric //
120b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
130b57cec5SDimitry Andric 
140b57cec5SDimitry Andric #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
150b57cec5SDimitry Andric #include "llvm/IR/BasicBlock.h"
160b57cec5SDimitry Andric #include "llvm/IR/Function.h"
170b57cec5SDimitry Andric #include "llvm/IR/Instructions.h"
180b57cec5SDimitry Andric #include "llvm/IR/Type.h"
19480093f4SDimitry Andric #include "llvm/InitializePasses.h"
200b57cec5SDimitry Andric #include "llvm/Transforms/Utils.h"
210b57cec5SDimitry Andric using namespace llvm;
220b57cec5SDimitry Andric 
23*af732203SDimitry Andric char UnifyFunctionExitNodesLegacyPass::ID = 0;
24480093f4SDimitry Andric 
UnifyFunctionExitNodesLegacyPass()25*af732203SDimitry Andric UnifyFunctionExitNodesLegacyPass::UnifyFunctionExitNodesLegacyPass()
26*af732203SDimitry Andric     : FunctionPass(ID) {
27*af732203SDimitry Andric   initializeUnifyFunctionExitNodesLegacyPassPass(
28*af732203SDimitry Andric       *PassRegistry::getPassRegistry());
29480093f4SDimitry Andric }
30480093f4SDimitry Andric 
// Registers the legacy pass under the command-line name "mergereturn" with the
// description "Unify function exit nodes". NOTE(review): the two trailing
// `false` arguments are INITIALIZE_PASS's boolean flags (presumably "CFG only"
// and "is analysis") — confirm against llvm/PassSupport.h.
INITIALIZE_PASS(UnifyFunctionExitNodesLegacyPass, "mergereturn",
                "Unify function exit nodes", false, false)
330b57cec5SDimitry Andric 
createUnifyFunctionExitNodesPass()340b57cec5SDimitry Andric Pass *llvm::createUnifyFunctionExitNodesPass() {
35*af732203SDimitry Andric   return new UnifyFunctionExitNodesLegacyPass();
360b57cec5SDimitry Andric }
370b57cec5SDimitry Andric 
getAnalysisUsage(AnalysisUsage & AU) const38*af732203SDimitry Andric void UnifyFunctionExitNodesLegacyPass::getAnalysisUsage(
39*af732203SDimitry Andric     AnalysisUsage &AU) const {
400b57cec5SDimitry Andric   // We preserve the non-critical-edgeness property
410b57cec5SDimitry Andric   AU.addPreservedID(BreakCriticalEdgesID);
420b57cec5SDimitry Andric   // This is a cluster of orthogonal Transforms
430b57cec5SDimitry Andric   AU.addPreservedID(LowerSwitchID);
440b57cec5SDimitry Andric }
450b57cec5SDimitry Andric 
46*af732203SDimitry Andric namespace {
47*af732203SDimitry Andric 
unifyUnreachableBlocks(Function & F)48*af732203SDimitry Andric bool unifyUnreachableBlocks(Function &F) {
490b57cec5SDimitry Andric   std::vector<BasicBlock *> UnreachableBlocks;
50*af732203SDimitry Andric 
510b57cec5SDimitry Andric   for (BasicBlock &I : F)
52*af732203SDimitry Andric     if (isa<UnreachableInst>(I.getTerminator()))
530b57cec5SDimitry Andric       UnreachableBlocks.push_back(&I);
540b57cec5SDimitry Andric 
55*af732203SDimitry Andric   if (UnreachableBlocks.size() <= 1)
56*af732203SDimitry Andric     return false;
57*af732203SDimitry Andric 
58*af732203SDimitry Andric   BasicBlock *UnreachableBlock =
59*af732203SDimitry Andric       BasicBlock::Create(F.getContext(), "UnifiedUnreachableBlock", &F);
600b57cec5SDimitry Andric   new UnreachableInst(F.getContext(), UnreachableBlock);
610b57cec5SDimitry Andric 
620b57cec5SDimitry Andric   for (BasicBlock *BB : UnreachableBlocks) {
630b57cec5SDimitry Andric     BB->getInstList().pop_back(); // Remove the unreachable inst.
640b57cec5SDimitry Andric     BranchInst::Create(UnreachableBlock, BB);
650b57cec5SDimitry Andric   }
66*af732203SDimitry Andric 
67*af732203SDimitry Andric   return true;
680b57cec5SDimitry Andric }
690b57cec5SDimitry Andric 
unifyReturnBlocks(Function & F)70*af732203SDimitry Andric bool unifyReturnBlocks(Function &F) {
71*af732203SDimitry Andric   std::vector<BasicBlock *> ReturningBlocks;
72*af732203SDimitry Andric 
73*af732203SDimitry Andric   for (BasicBlock &I : F)
74*af732203SDimitry Andric     if (isa<ReturnInst>(I.getTerminator()))
75*af732203SDimitry Andric       ReturningBlocks.push_back(&I);
76*af732203SDimitry Andric 
77*af732203SDimitry Andric   if (ReturningBlocks.size() <= 1)
780b57cec5SDimitry Andric     return false;
790b57cec5SDimitry Andric 
80*af732203SDimitry Andric   // Insert a new basic block into the function, add PHI nodes (if the function
81*af732203SDimitry Andric   // returns values), and convert all of the return instructions into
82*af732203SDimitry Andric   // unconditional branches.
830b57cec5SDimitry Andric   BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(),
840b57cec5SDimitry Andric                                                "UnifiedReturnBlock", &F);
850b57cec5SDimitry Andric 
860b57cec5SDimitry Andric   PHINode *PN = nullptr;
870b57cec5SDimitry Andric   if (F.getReturnType()->isVoidTy()) {
880b57cec5SDimitry Andric     ReturnInst::Create(F.getContext(), nullptr, NewRetBlock);
890b57cec5SDimitry Andric   } else {
900b57cec5SDimitry Andric     // If the function doesn't return void... add a PHI node to the block...
910b57cec5SDimitry Andric     PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(),
920b57cec5SDimitry Andric                          "UnifiedRetVal");
930b57cec5SDimitry Andric     NewRetBlock->getInstList().push_back(PN);
940b57cec5SDimitry Andric     ReturnInst::Create(F.getContext(), PN, NewRetBlock);
950b57cec5SDimitry Andric   }
960b57cec5SDimitry Andric 
970b57cec5SDimitry Andric   // Loop over all of the blocks, replacing the return instruction with an
980b57cec5SDimitry Andric   // unconditional branch.
990b57cec5SDimitry Andric   for (BasicBlock *BB : ReturningBlocks) {
1000b57cec5SDimitry Andric     // Add an incoming element to the PHI node for every return instruction that
1010b57cec5SDimitry Andric     // is merging into this new block...
1020b57cec5SDimitry Andric     if (PN)
1030b57cec5SDimitry Andric       PN->addIncoming(BB->getTerminator()->getOperand(0), BB);
1040b57cec5SDimitry Andric 
1050b57cec5SDimitry Andric     BB->getInstList().pop_back();  // Remove the return insn
1060b57cec5SDimitry Andric     BranchInst::Create(NewRetBlock, BB);
1070b57cec5SDimitry Andric   }
108*af732203SDimitry Andric 
1090b57cec5SDimitry Andric   return true;
1100b57cec5SDimitry Andric }
111*af732203SDimitry Andric } // namespace
112*af732203SDimitry Andric 
113*af732203SDimitry Andric // Unify all exit nodes of the CFG by creating a new BasicBlock, and converting
114*af732203SDimitry Andric // all returns to unconditional branches to this new basic block. Also, unify
115*af732203SDimitry Andric // all unreachable blocks.
runOnFunction(Function & F)116*af732203SDimitry Andric bool UnifyFunctionExitNodesLegacyPass::runOnFunction(Function &F) {
117*af732203SDimitry Andric   bool Changed = false;
118*af732203SDimitry Andric   Changed |= unifyUnreachableBlocks(F);
119*af732203SDimitry Andric   Changed |= unifyReturnBlocks(F);
120*af732203SDimitry Andric   return Changed;
121*af732203SDimitry Andric }
122*af732203SDimitry Andric 
run(Function & F,FunctionAnalysisManager & AM)123*af732203SDimitry Andric PreservedAnalyses UnifyFunctionExitNodesPass::run(Function &F,
124*af732203SDimitry Andric                                                   FunctionAnalysisManager &AM) {
125*af732203SDimitry Andric   bool Changed = false;
126*af732203SDimitry Andric   Changed |= unifyUnreachableBlocks(F);
127*af732203SDimitry Andric   Changed |= unifyReturnBlocks(F);
128*af732203SDimitry Andric   return Changed ? PreservedAnalyses() : PreservedAnalyses::all();
129*af732203SDimitry Andric }
130