1 //===--- ExtTSPReorderAlgorithm.cpp - Order basic blocks ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // ExtTSP - layout of basic blocks with i-cache optimization.
9 //
10 // The algorithm is a greedy heuristic that works with chains (ordered lists)
11 // of basic blocks. Initially all chains are isolated basic blocks. On every
12 // iteration, we pick a pair of chains whose merging yields the biggest increase
13 // in the ExtTSP value, which models how i-cache "friendly" a specific chain is.
14 // A pair of chains giving the maximum gain is merged into a new chain. The
15 // procedure stops when there is only one chain left, or when merging does not
16 // increase ExtTSP. In the latter case, the remaining chains are sorted by
17 // density in decreasing order.
18 //
19 // An important aspect is the way two chains are merged. Unlike earlier
20 // algorithms (e.g., OptimizeCacheReorderAlgorithm or Pettis-Hansen), two
// chains, X and Y, are first split into three, X1, X2, and Y. Then we
// consider several ways of gluing the three chains and choose the one
// producing the largest score; this implementation tries X1X2Y, X1YX2, YX2X1,
// and X2X1Y (the other permutations, X2YX1 and YX1X2, are not evaluated).
24 // This improves the quality of the final result (the search space is larger)
25 // while keeping the implementation sufficiently fast.
26 //
27 // Reference:
28 //   * A. Newell and S. Pupyrev, Improved Basic Block Reordering,
29 //     IEEE Transactions on Computers, 2020
30 //     https://arxiv.org/abs/1809.04676
31 //===----------------------------------------------------------------------===//
32 #include "bolt/Core/BinaryBasicBlock.h"
33 #include "bolt/Core/BinaryFunction.h"
34 #include "bolt/Passes/ReorderAlgorithm.h"
35 #include "llvm/Support/CommandLine.h"
36 
37 using namespace llvm;
38 using namespace bolt;
39 namespace opts {
40 
41 extern cl::OptionCategory BoltOptCategory;
42 extern cl::opt<bool> NoThreads;
43 
44 cl::opt<unsigned>
45 ChainSplitThreshold("chain-split-threshold",
46   cl::desc("The maximum size of a chain to apply splitting"),
47   cl::init(128),
48   cl::ReallyHidden,
49   cl::ZeroOrMore,
50   cl::cat(BoltOptCategory));
51 
52 cl::opt<double>
53 ForwardWeight("forward-weight",
54   cl::desc("The weight of forward jumps for ExtTSP value"),
55   cl::init(0.1),
56   cl::ReallyHidden,
57   cl::ZeroOrMore,
58   cl::cat(BoltOptCategory));
59 
60 cl::opt<double>
61 BackwardWeight("backward-weight",
62   cl::desc("The weight of backward jumps for ExtTSP value"),
63   cl::init(0.1),
64   cl::ReallyHidden,
65   cl::ZeroOrMore,
66   cl::cat(BoltOptCategory));
67 
68 cl::opt<unsigned>
69 ForwardDistance("forward-distance",
70   cl::desc("The maximum distance (in bytes) of forward jumps for ExtTSP value"),
71   cl::init(1024),
72   cl::ReallyHidden,
73   cl::ZeroOrMore,
74   cl::cat(BoltOptCategory));
75 
76 cl::opt<unsigned>
77 BackwardDistance("backward-distance",
78   cl::desc("The maximum distance (in bytes) of backward jumps for ExtTSP value"),
79   cl::init(640),
80   cl::ReallyHidden,
81   cl::ZeroOrMore,
82   cl::cat(BoltOptCategory));
83 
84 }
85 
86 namespace llvm {
87 namespace bolt {
88 
/// Tolerance used whenever floating-point scores and gains are compared.
constexpr double EPS{1e-8};

// Forward declarations of the algorithm's core data structures, which refer
// to one another through pointers.
class Block;
class Chain;
class Edge;
95 
96 // Calculate Ext-TSP value, which quantifies the expected number of i-cache
97 // misses for a given ordering of basic blocks
98 double extTSPScore(uint64_t SrcAddr,
99                    uint64_t SrcSize,
100                    uint64_t DstAddr,
101                    uint64_t Count) {
102   assert(Count != BinaryBasicBlock::COUNT_NO_PROFILE);
103 
104   // Fallthrough
105   if (SrcAddr + SrcSize == DstAddr) {
106     // Assume that FallthroughWeight = 1.0 after normalization
107     return static_cast<double>(Count);
108   }
109   // Forward
110   if (SrcAddr + SrcSize < DstAddr) {
111     const uint64_t Dist = DstAddr - (SrcAddr + SrcSize);
112     if (Dist <= opts::ForwardDistance) {
113       double Prob = 1.0 - static_cast<double>(Dist) / opts::ForwardDistance;
114       return opts::ForwardWeight * Prob * Count;
115     }
116     return 0;
117   }
118   // Backward
119   const uint64_t Dist = SrcAddr + SrcSize - DstAddr;
120   if (Dist <= opts::BackwardDistance) {
121     double Prob = 1.0 - static_cast<double>(Dist) / opts::BackwardDistance;
122     return opts::BackwardWeight * Prob * Count;
123   }
124   return 0;
125 }
126 
127 using BlockPair = std::pair<Block *, Block *>;
128 using JumpList = std::vector<std::pair<BlockPair, uint64_t>>;
129 using BlockIter = std::vector<Block *>::const_iterator;
130 
131 enum MergeTypeTy {
132   X_Y = 0,
133   X1_Y_X2 = 1,
134   Y_X2_X1 = 2,
135   X2_X1_Y = 3,
136 };
137 
138 class MergeGainTy {
139 public:
140   explicit MergeGainTy() {}
141   explicit MergeGainTy(double Score, size_t MergeOffset, MergeTypeTy MergeType)
142     : Score(Score),
143       MergeOffset(MergeOffset),
144       MergeType(MergeType) {}
145 
146   double score() const {
147     return Score;
148   }
149 
150   size_t mergeOffset() const {
151     return MergeOffset;
152   }
153 
154   MergeTypeTy mergeType() const {
155     return MergeType;
156   }
157 
158   // returns 'true' iff Other is preferred over this
159   bool operator < (const MergeGainTy& Other) const {
160     return (Other.Score > EPS && Other.Score > Score + EPS);
161   }
162 
163 private:
164   double Score{-1.0};
165   size_t MergeOffset{0};
166   MergeTypeTy MergeType{MergeTypeTy::X_Y};
167 };
168 
169 // A node in CFG corresponding to a BinaryBasicBlock.
170 // The class wraps several mutable fields utilized in the ExtTSP algorithm
171 class Block {
172 public:
173   Block(const Block&) = delete;
174   Block(Block&&) = default;
175   Block& operator=(const Block&) = delete;
176   Block& operator=(Block&&) = default;
177 
178   // Corresponding basic block
179   BinaryBasicBlock *BB{nullptr};
180   // Current chain of the basic block
181   Chain *CurChain{nullptr};
182   // (Estimated) size of the block in the binary
183   uint64_t Size{0};
184   // Execution count of the block in the binary
185   uint64_t ExecutionCount{0};
186   // An original index of the node in CFG
187   size_t Index{0};
188   // The index of the block in the current chain
189   size_t CurIndex{0};
190   // An offset of the block in the current chain
191   mutable uint64_t EstimatedAddr{0};
192   // Fallthrough successor of the node in CFG
193   Block *FallthroughSucc{nullptr};
194   // Fallthrough predecessor of the node in CFG
195   Block *FallthroughPred{nullptr};
196   // Outgoing jumps from the block
197   std::vector<std::pair<Block *, uint64_t>> OutJumps;
198   // Incoming jumps to the block
199   std::vector<std::pair<Block *, uint64_t>> InJumps;
200   // Total execution count of incoming jumps
201   uint64_t InWeight{0};
202   // Total execution count of outgoing jumps
203   uint64_t OutWeight{0};
204 
205 public:
206   explicit Block(BinaryBasicBlock *BB_, uint64_t Size_)
207     : BB(BB_),
208       Size(Size_),
209       ExecutionCount(BB_->getKnownExecutionCount()),
210       Index(BB->getLayoutIndex()) {}
211 
212   bool adjacent(const Block *Other) const {
213     return hasOutJump(Other) || hasInJump(Other);
214   }
215 
216   bool hasOutJump(const Block *Other) const {
217     for (std::pair<Block *, uint64_t> Jump : OutJumps) {
218       if (Jump.first == Other)
219         return true;
220     }
221     return false;
222   }
223 
224   bool hasInJump(const Block *Other) const {
225     for (std::pair<Block *, uint64_t> Jump : InJumps) {
226       if (Jump.first == Other)
227         return true;
228     }
229     return false;
230   }
231 };
232 
233 // A chain (ordered sequence) of CFG nodes (basic blocks)
234 class Chain {
235 public:
236   Chain(const Chain&) = delete;
237   Chain(Chain&&) = default;
238   Chain& operator=(const Chain&) = delete;
239   Chain& operator=(Chain&&) = default;
240 
241   explicit Chain(size_t Id, Block *Block)
242     : Id(Id),
243       IsEntry(Block->Index == 0),
244       ExecutionCount(Block->ExecutionCount),
245       Size(Block->Size),
246       Score(0),
247       Blocks(1, Block) {}
248 
249   size_t id() const {
250     return Id;
251   }
252 
253   uint64_t size() const {
254     return Size;
255   }
256 
257   double density() const {
258     return static_cast<double>(ExecutionCount) / Size;
259   }
260 
261   uint64_t executionCount() const {
262     return ExecutionCount;
263   }
264 
265   bool isEntryPoint() const {
266     return IsEntry;
267   }
268 
269   double score() const {
270     return Score;
271   }
272 
273   void setScore(double NewScore) {
274     Score = NewScore;
275   }
276 
277   const std::vector<Block *> &blocks() const {
278     return Blocks;
279   }
280 
281   const std::vector<std::pair<Chain *, Edge *>> &edges() const {
282     return Edges;
283   }
284 
285   Edge *getEdge(Chain *Other) const {
286     for (std::pair<Chain *, Edge *> It : Edges) {
287       if (It.first == Other)
288         return It.second;
289     }
290     return nullptr;
291   }
292 
293   void removeEdge(Chain *Other) {
294     auto It = Edges.begin();
295     while (It != Edges.end()) {
296       if (It->first == Other) {
297         Edges.erase(It);
298         return;
299       }
300       It++;
301     }
302   }
303 
304   void addEdge(Chain *Other, Edge *Edge) { Edges.emplace_back(Other, Edge); }
305 
306   void merge(Chain *Other, const std::vector<Block *> &MergedBlocks) {
307     Blocks = MergedBlocks;
308     IsEntry |= Other->IsEntry;
309     ExecutionCount += Other->ExecutionCount;
310     Size += Other->Size;
311     // Update block's chains
312     for (size_t Idx = 0; Idx < Blocks.size(); Idx++) {
313       Blocks[Idx]->CurChain = this;
314       Blocks[Idx]->CurIndex = Idx;
315     }
316   }
317 
318   void mergeEdges(Chain *Other);
319 
320   void clear() {
321     Blocks.clear();
322     Edges.clear();
323   }
324 
325 private:
326   size_t Id;
327   bool IsEntry;
328   uint64_t ExecutionCount;
329   uint64_t Size;
330   // Cached ext-tsp score for the chain
331   double Score;
332   // Blocks of the chain
333   std::vector<Block *> Blocks;
334   // Adjacent chains and corresponding edges (lists of jumps)
335   std::vector<std::pair<Chain *, Edge *>> Edges;
336 };
337 
338 // An edge in CFG reprsenting jumps between chains of BinaryBasicBlocks.
339 // When blocks are merged into chains, the edges are combined too so that
340 // there is always at most one edge between a pair of chains
341 class Edge {
342 public:
343   Edge(const Edge&) = delete;
344   Edge(Edge&&) = default;
345   Edge& operator=(const Edge&) = delete;
346   Edge& operator=(Edge&&) = default;
347 
348   explicit Edge(Block *SrcBlock, Block *DstBlock, uint64_t EC)
349     : SrcChain(SrcBlock->CurChain),
350       DstChain(DstBlock->CurChain),
351       Jumps(1, std::make_pair(std::make_pair(SrcBlock, DstBlock), EC)) {}
352 
353   const JumpList &jumps() const {
354     return Jumps;
355   }
356 
357   void changeEndpoint(Chain *From, Chain *To) {
358     if (From == SrcChain)
359       SrcChain = To;
360     if (From == DstChain)
361       DstChain = To;
362   }
363 
364   void appendJump(Block *SrcBlock, Block *DstBlock, uint64_t EC) {
365     Jumps.emplace_back(std::make_pair(SrcBlock, DstBlock), EC);
366   }
367 
368   void moveJumps(Edge *Other) {
369     Jumps.insert(Jumps.end(), Other->Jumps.begin(), Other->Jumps.end());
370     Other->Jumps.clear();
371   }
372 
373   bool hasCachedMergeGain(Chain *Src, Chain *Dst) const {
374     return Src == SrcChain ? CacheValidForward : CacheValidBackward;
375   }
376 
377   MergeGainTy getCachedMergeGain(Chain *Src, Chain *Dst) const {
378     return Src == SrcChain ? CachedGainForward : CachedGainBackward;
379   }
380 
381   void setCachedMergeGain(Chain *Src, Chain *Dst, MergeGainTy MergeGain) {
382     if (Src == SrcChain) {
383       CachedGainForward = MergeGain;
384       CacheValidForward = true;
385     } else {
386       CachedGainBackward = MergeGain;
387       CacheValidBackward = true;
388     }
389   }
390 
391   void invalidateCache() {
392     CacheValidForward = false;
393     CacheValidBackward = false;
394   }
395 
396 private:
397   Chain *SrcChain{nullptr};
398   Chain *DstChain{nullptr};
399   // Original jumps in the binary with correspinding execution counts
400   JumpList Jumps;
401   // Cached ext-tsp value for merging the pair of chains
402   // Since the gain of merging (Src, Dst) and (Dst, Src) might be different,
403   // we store both values here
404   MergeGainTy CachedGainForward;
405   MergeGainTy CachedGainBackward;
406   // Whether the cached value must be recomputed
407   bool CacheValidForward{false};
408   bool CacheValidBackward{false};
409 };
410 
/// Moves every edge of chain \p Other onto this chain; used while merging
/// \p Other into this chain. For each edge of Other:
/// - an edge between Other and this chain, or a self-edge of Other, becomes
///   (or is folded into) the self-edge of this chain;
/// - an edge between Other and a third chain C is re-pointed to connect this
///   chain with C, or, if this chain already has an edge to C, its jumps are
///   appended to that existing edge.
/// Every neighbor's adjacency entry for Other is removed. Other's own edge
/// list is not modified here (the caller is expected to clear it).
void Chain::mergeEdges(Chain *Other) {
  assert(this != Other && "cannot merge a chain with itself");

  // Update edges adjacent to chain Other
  for (auto EdgeIt : Other->Edges) {
    Chain *const DstChain = EdgeIt.first;
    Edge *const DstEdge = EdgeIt.second;
    // A self-edge of Other becomes a self-edge of this chain
    Chain *const TargetChain = DstChain == Other ? this : DstChain;

    // Find the corresponding edge in the current chain
    Edge *curEdge = getEdge(TargetChain);
    if (curEdge == nullptr) {
      // No such edge yet: reuse DstEdge, replacing its Other endpoint(s)
      DstEdge->changeEndpoint(Other, this);
      this->addEdge(TargetChain, DstEdge);
      // Register the edge on the far side too, unless it is now a self-edge
      if (DstChain != this && DstChain != Other) {
        DstChain->addEdge(this, DstEdge);
      }
    } else {
      // Fold the jumps into the existing edge; DstEdge is left empty
      curEdge->moveJumps(DstEdge);
    }
    // Cleanup leftover edge
    if (DstChain != Other) {
      DstChain->removeEdge(Other);
    }
  }
}
437 
438 // A wrapper around three chains of basic blocks; it is used to avoid extra
439 // instantiation of the vectors.
440 class MergedChain {
441 public:
442   MergedChain(BlockIter Begin1,
443               BlockIter End1,
444               BlockIter Begin2 = BlockIter(),
445               BlockIter End2 = BlockIter(),
446               BlockIter Begin3 = BlockIter(),
447               BlockIter End3 = BlockIter())
448   : Begin1(Begin1),
449     End1(End1),
450     Begin2(Begin2),
451     End2(End2),
452     Begin3(Begin3),
453     End3(End3) {}
454 
455   template<typename F>
456   void forEach(const F &Func) const {
457     for (auto It = Begin1; It != End1; It++)
458       Func(*It);
459     for (auto It = Begin2; It != End2; It++)
460       Func(*It);
461     for (auto It = Begin3; It != End3; It++)
462       Func(*It);
463   }
464 
465   std::vector<Block *> getBlocks() const {
466     std::vector<Block *> Result;
467     Result.reserve(std::distance(Begin1, End1) +
468                    std::distance(Begin2, End2) +
469                    std::distance(Begin3, End3));
470     Result.insert(Result.end(), Begin1, End1);
471     Result.insert(Result.end(), Begin2, End2);
472     Result.insert(Result.end(), Begin3, End3);
473     return Result;
474   }
475 
476   const Block *getFirstBlock() const {
477     return *Begin1;
478   }
479 
480 private:
481   BlockIter Begin1;
482   BlockIter End1;
483   BlockIter Begin2;
484   BlockIter End2;
485   BlockIter Begin3;
486   BlockIter End3;
487 };
488 
489 /// Deterministically compare pairs of chains
490 bool compareChainPairs(const Chain *A1, const Chain *B1,
491                        const Chain *A2, const Chain *B2) {
492   const uint64_t Samples1 = A1->executionCount() + B1->executionCount();
493   const uint64_t Samples2 = A2->executionCount() + B2->executionCount();
494   if (Samples1 != Samples2)
495     return Samples1 < Samples2;
496 
497   // Making the order deterministic
498   if (A1 != A2)
499     return A1->id() < A2->id();
500   return B1->id() < B2->id();
501 }
/// Driver of the ExtTSP block-ordering heuristic for a single function.
/// The constructor builds blocks, single-block chains, and inter-chain edges
/// from the function's current layout and profile; run() merges chains and
/// emits the final order.
class ExtTSP {
public:
  ExtTSP(const BinaryFunction &BF) : BF(BF) {
    initialize();
  }

  /// Run the algorithm and append the resulting ordering of basic blocks to
  /// \p Order
  void run(BinaryFunction::BasicBlockOrderType &Order) {
    // Pass 1: Merge blocks with their fallthrough successors
    mergeFallthroughs();

    // Pass 2: Merge pairs of chains while improving the ExtTSP objective
    mergeChainPairs();

    // Pass 3: Merge cold blocks to reduce code size
    mergeColdChains();

    // Collect blocks from all chains
    concatChains(Order);
  }

private:
  /// Initialize algorithm's data structures: one Block and one single-block
  /// Chain per basic block in layout order, plus the Edges between chains
  void initialize() {
    // Create a separate MCCodeEmitter to allow lock-free execution. With
    // opts::NoThreads the emitter is left default-constructed, and
    // Emitter.MCE.get() is what gets passed to estimateSize()
    BinaryContext::IndependentCodeEmitter Emitter;
    if (!opts::NoThreads) {
      Emitter = BF.getBinaryContext().createIndependentMCCodeEmitter();
    }

    // Initialize CFG nodes. Layout indices are reassigned here, so that
    // AllBlocks[BB->getLayoutIndex()] is the node of BB
    AllBlocks.reserve(BF.layout_size());
    size_t LayoutIndex = 0;
    for (BinaryBasicBlock *BB : BF.layout()) {
      BB->setLayoutIndex(LayoutIndex++);
      // Every block is treated as at least one byte long
      uint64_t Size =
          std::max<uint64_t>(BB->estimateSize(Emitter.MCE.get()), 1);
      AllBlocks.emplace_back(BB, Size);
    }

    // Initialize edges for the blocks and compute their total in/out weights.
    // Self-loops and jumps with a zero count are not recorded
    size_t NumEdges = 0;
    for (Block &Block : AllBlocks) {
      // Branch info is advanced in lockstep with the successor list
      auto BI = Block.BB->branch_info_begin();
      for (BinaryBasicBlock *SuccBB : Block.BB->successors()) {
        assert(BI->Count != BinaryBasicBlock::COUNT_NO_PROFILE &&
               "missing profile for a jump");
        if (SuccBB != Block.BB && BI->Count > 0) {
          class Block &SuccBlock = AllBlocks[SuccBB->getLayoutIndex()];
          uint64_t Count = BI->Count;
          SuccBlock.InWeight += Count;
          SuccBlock.InJumps.emplace_back(&Block, Count);
          Block.OutWeight += Count;
          Block.OutJumps.emplace_back(&SuccBlock, Count);
          NumEdges++;
        }
        ++BI;
      }
    }

    // Raise the execution count of every basic block to at least the sum of
    // its incoming jump counts and the sum of its outgoing jump counts.
    // Also execution count of the entry point is set to at least 1
    for (Block &Block : AllBlocks) {
      size_t Index = Block.Index;
      Block.ExecutionCount = std::max(Block.ExecutionCount, Block.InWeight);
      Block.ExecutionCount = std::max(Block.ExecutionCount, Block.OutWeight);
      if (Index == 0 && Block.ExecutionCount == 0)
        Block.ExecutionCount = 1;
    }

    // Initialize chains: one chain per block, with the block's index as the
    // chain id. Only chains with a positive execution count are hot. The
    // vector is reserved up front so the Chain pointers taken below stay valid
    AllChains.reserve(BF.layout_size());
    HotChains.reserve(BF.layout_size());
    for (Block &Block : AllBlocks) {
      AllChains.emplace_back(Block.Index, &Block);
      Block.CurChain = &AllChains.back();
      if (Block.ExecutionCount > 0) {
        HotChains.push_back(&AllChains.back());
      }
    }

    // Initialize edges: all jumps between the same pair of chains (in either
    // direction) share one Edge, registered on both chains. At most NumEdges
    // edges are created, so reserving that many keeps Edge pointers valid
    AllEdges.reserve(NumEdges);
    for (Block &Block : AllBlocks) {
      for (std::pair<class Block *, uint64_t> &Jump : Block.OutJumps) {
        class Block *const SuccBlock = Jump.first;
        Edge *CurEdge = Block.CurChain->getEdge(SuccBlock->CurChain);
        // this edge is already present in the graph
        if (CurEdge != nullptr) {
          assert(SuccBlock->CurChain->getEdge(Block.CurChain) != nullptr);
          CurEdge->appendJump(&Block, SuccBlock, Jump.second);
          continue;
        }
        // this is a new edge
        AllEdges.emplace_back(&Block, SuccBlock, Jump.second);
        Block.CurChain->addEdge(SuccBlock->CurChain, &AllEdges.back());
        SuccBlock->CurChain->addEdge(Block.CurChain, &AllEdges.back());
      }
    }
    assert(AllEdges.size() <= NumEdges && "Incorrect number of created edges");
  }

  /// For a pair of blocks, A and B, block B is the fallthrough successor of A,
  /// if (i) all jumps (based on profile) from A go to B and (ii) all jumps
  /// to B are from A. B is also treated as A's fallthrough successor when B is
  /// A's only CFG successor and A is B's only CFG predecessor. The entry block
  /// is never a fallthrough successor. Such blocks should be adjacent in an
  /// optimal ordering; the method finds and merges such pairs of blocks
  void mergeFallthroughs() {
    // Find fallthroughs based on edge weights
    for (Block &Block : AllBlocks) {
      // Structural case: a single successor that has a single predecessor
      if (Block.BB->succ_size() == 1 &&
          Block.BB->getSuccessor()->pred_size() == 1 &&
          Block.BB->getSuccessor()->getLayoutIndex() != 0) {
        size_t SuccIndex = Block.BB->getSuccessor()->getLayoutIndex();
        Block.FallthroughSucc = &AllBlocks[SuccIndex];
        AllBlocks[SuccIndex].FallthroughPred = &Block;
        continue;
      }

      if (Block.OutWeight == 0)
        continue;
      for (std::pair<class Block *, uint64_t> &Edge : Block.OutJumps) {
        class Block *const SuccBlock = Edge.first;
        // Successor cannot be the first BB, which is pinned
        if (Block.OutWeight == Edge.second &&
            SuccBlock->InWeight == Edge.second &&
            SuccBlock->Index != 0) {
          Block.FallthroughSucc = SuccBlock;
          SuccBlock->FallthroughPred = &Block;
          break;
        }
      }
    }

    // There might be 'cycles' in the fallthrough dependencies (since profile
    // data isn't 100% accurate).
    // Break the cycles by choosing the block with smallest index as the head:
    // blocks are visited in index order, and the first block found to lie on
    // a cycle is detached from its fallthrough predecessor
    for (Block &Block : AllBlocks) {
      if (Block.FallthroughSucc == nullptr || Block.FallthroughPred == nullptr)
        continue;

      class Block *SuccBlock = Block.FallthroughSucc;
      while (SuccBlock != nullptr && SuccBlock != &Block) {
        SuccBlock = SuccBlock->FallthroughSucc;
      }
      if (SuccBlock == nullptr)
        continue;
      // break the cycle
      AllBlocks[Block.FallthroughPred->Index].FallthroughSucc = nullptr;
      Block.FallthroughPred = nullptr;
    }

    // Merge blocks with their fallthrough successors, walking each
    // fallthrough sequence from its head (a block without a predecessor)
    for (Block &Block : AllBlocks) {
      if (Block.FallthroughPred == nullptr &&
          Block.FallthroughSucc != nullptr) {
        class Block *CurBlock = &Block;
        while (CurBlock->FallthroughSucc != nullptr) {
          class Block *const NextBlock = CurBlock->FallthroughSucc;
          mergeChains(Block.CurChain, NextBlock->CurChain, 0, MergeTypeTy::X_Y);
          CurBlock = NextBlock;
        }
      }
    }
  }

  /// Merge pairs of chains while improving the ExtTSP objective.
  /// Each iteration evaluates every edge of every hot chain and merges the
  /// pair with the largest gain (near-ties broken by compareChainPairs); it
  /// stops when no merge gains more than EPS or one hot chain remains
  void mergeChainPairs() {
    while (HotChains.size() > 1) {
      Chain *BestChainPred = nullptr;
      Chain *BestChainSucc = nullptr;
      auto BestGain = MergeGainTy();
      // Iterate over all pairs of chains
      for (Chain *ChainPred : HotChains) {
        // Get candidates for merging with the current chain
        for (auto EdgeIter : ChainPred->edges()) {
          Chain *ChainSucc = EdgeIter.first;
          Edge *ChainEdge = EdgeIter.second;
          // Ignore loop edges
          if (ChainPred == ChainSucc)
            continue;

          // Compute the gain of merging the two chains
          MergeGainTy CurGain = mergeGain(ChainPred, ChainSucc, ChainEdge);
          if (CurGain.score() <= EPS)
            continue;

          if (BestGain < CurGain ||
              (std::abs(CurGain.score() - BestGain.score()) < EPS &&
               compareChainPairs(ChainPred,
                                 ChainSucc,
                                 BestChainPred,
                                 BestChainSucc))) {
            BestGain = CurGain;
            BestChainPred = ChainPred;
            BestChainSucc = ChainSucc;
          }
        }
      }

      // Stop merging when there is no improvement
      if (BestGain.score() <= EPS)
        break;

      // Merge the best pair of chains
      mergeChains(BestChainPred,
                  BestChainSucc,
                  BestGain.mergeOffset(),
                  BestGain.mergeType());
    }
  }

  /// Merge cold blocks to reduce code size: walking the original layout, the
  /// chain that starts with a successor DstBB is appended to the chain that
  /// ends with SrcBB, unless the successor's chain holds the entry block
  void mergeColdChains() {
    for (BinaryBasicBlock *SrcBB : BF.layout()) {
      // Iterating in reverse order to make sure original fallthrough jumps are
      // merged first
      for (auto Itr = SrcBB->succ_rbegin(); Itr != SrcBB->succ_rend(); ++Itr) {
        BinaryBasicBlock *DstBB = *Itr;
        size_t SrcIndex = SrcBB->getLayoutIndex();
        size_t DstIndex = DstBB->getLayoutIndex();
        Chain *SrcChain = AllBlocks[SrcIndex].CurChain;
        Chain *DstChain = AllBlocks[DstIndex].CurChain;
        if (SrcChain != DstChain && !DstChain->isEntryPoint() &&
            SrcChain->blocks().back()->Index == SrcIndex &&
            DstChain->blocks().front()->Index == DstIndex) {
          mergeChains(SrcChain, DstChain, 0, MergeTypeTy::X_Y);
        }
      }
    }
  }

  /// Compute ExtTSP score for a given order of basic blocks: the blocks of
  /// MergedBlocks are laid out contiguously from offset 0 (each block's
  /// offset is stored in its EstimatedAddr), and extTSPScore is summed over
  /// \p Jumps
  double score(const MergedChain &MergedBlocks, const JumpList &Jumps) const {
    if (Jumps.empty())
      return 0.0;
    uint64_t CurAddr = 0;
    MergedBlocks.forEach(
      [&](const Block *BB) {
        BB->EstimatedAddr = CurAddr;
        CurAddr += BB->Size;
      }
    );

    double Score = 0;
    for (const std::pair<std::pair<Block *, Block *>, uint64_t> &Jump : Jumps) {
      const Block *SrcBlock = Jump.first.first;
      const Block *DstBlock = Jump.first.second;
      Score += extTSPScore(SrcBlock->EstimatedAddr,
                           SrcBlock->Size,
                           DstBlock->EstimatedAddr,
                           Jump.second);
    }
    return Score;
  }

  /// Compute the gain of merging two chains
  ///
  /// The function tries concatenating the two chains as-is and, when
  /// ChainPred has at most opts::ChainSplitThreshold blocks, splitting
  /// ChainPred at every offset that does not separate a block from its
  /// fallthrough successor, combined with each of the three split merge
  /// types. The result is a MergeGainTy holding the largest increase in the
  /// ExtTSP objective together with the offset and merge type achieving it;
  /// it is also cached on \p Edge.
  MergeGainTy mergeGain(Chain *ChainPred, Chain *ChainSucc, Edge *Edge) const {
    if (Edge->hasCachedMergeGain(ChainPred, ChainSucc)) {
      return Edge->getCachedMergeGain(ChainPred, ChainSucc);
    }

    // Precompute jumps between ChainPred and ChainSucc
    JumpList Jumps = Edge->jumps();
    class Edge *EdgePP = ChainPred->getEdge(ChainPred);
    if (EdgePP != nullptr)
      Jumps.insert(Jumps.end(), EdgePP->jumps().begin(), EdgePP->jumps().end());
    assert(Jumps.size() > 0 && "trying to merge chains w/o jumps");

    MergeGainTy Gain = MergeGainTy();
    // Try to concatenate two chains w/o splitting
    Gain = computeMergeGain(
        Gain, ChainPred, ChainSucc, Jumps, 0, MergeTypeTy::X_Y);

    // Try to break ChainPred in various ways and concatenate with ChainSucc
    if (ChainPred->blocks().size() <= opts::ChainSplitThreshold) {
      for (size_t Offset = 1; Offset < ChainPred->blocks().size(); Offset++) {
        Block *BB1 = ChainPred->blocks()[Offset - 1];
        Block *BB2 = ChainPred->blocks()[Offset];
        // Does the splitting break FT successors?
        if (BB1->FallthroughSucc != nullptr) {
          (void)BB2;
          assert(BB1->FallthroughSucc == BB2 && "Fallthrough not preserved");
          continue;
        }

        Gain = computeMergeGain(
            Gain, ChainPred, ChainSucc, Jumps, Offset, MergeTypeTy::X1_Y_X2);
        Gain = computeMergeGain(
            Gain, ChainPred, ChainSucc, Jumps, Offset, MergeTypeTy::Y_X2_X1);
        Gain = computeMergeGain(
            Gain, ChainPred, ChainSucc, Jumps, Offset, MergeTypeTy::X2_X1_Y);
      }
    }

    Edge->setCachedMergeGain(ChainPred, ChainSucc, Gain);
    return Gain;
  }

  /// Evaluate merging ChainPred and ChainSucc with the given offset and type,
  /// and return whichever of this candidate and \p CurGain is preferred
  /// (according to MergeGainTy::operator<)
  MergeGainTy computeMergeGain(const MergeGainTy &CurGain,
                               const Chain *ChainPred,
                               const Chain *ChainSucc,
                               const JumpList &Jumps,
                               size_t MergeOffset,
                               MergeTypeTy MergeType) const {
    MergedChain MergedBlocks = mergeBlocks(
        ChainPred->blocks(), ChainSucc->blocks(), MergeOffset, MergeType);

    // Do not allow a merge that does not preserve the original entry block
    if ((ChainPred->isEntryPoint() || ChainSucc->isEntryPoint()) &&
        MergedBlocks.getFirstBlock()->Index != 0)
      return CurGain;

    // The gain for the new chain. ChainSucc always stays contiguous and in
    // order, so only ChainPred's cached internal score needs to be subtracted
    const double NewScore = score(MergedBlocks, Jumps) - ChainPred->score();
    auto NewGain = MergeGainTy(NewScore, MergeOffset, MergeType);
    return CurGain < NewGain ? NewGain : CurGain;
  }

  /// Merge two chains of blocks respecting a given merge 'type' and 'offset'
  ///
  /// If MergeType is X_Y, then the result is a concatenation of two chains.
  /// Otherwise, the first chain is cut into two sub-chains at the offset,
  /// and the three pieces are glued in the order given by the merge type.
  MergedChain mergeBlocks(const std::vector<Block *> &X,
                          const std::vector<Block *> &Y,
                          size_t MergeOffset,
                          MergeTypeTy MergeType) const {
    // Split the first chain, X, into X1 and X2
    BlockIter BeginX1 = X.begin();
    BlockIter EndX1 = X.begin() + MergeOffset;
    BlockIter BeginX2 = X.begin() + MergeOffset;
    BlockIter EndX2 = X.end();
    BlockIter BeginY = Y.begin();
    BlockIter EndY = Y.end();

    // Construct a new chain from the three existing ones
    switch(MergeType) {
      case MergeTypeTy::X_Y:
        return MergedChain(BeginX1, EndX2, BeginY, EndY);
      case MergeTypeTy::X1_Y_X2:
        return MergedChain(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2);
      case MergeTypeTy::Y_X2_X1:
        return MergedChain(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1);
      case MergeTypeTy::X2_X1_Y:
        return MergedChain(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY);
    }

    llvm_unreachable("unexpected merge type");
  }

  /// Merge chain From into chain Into, update the list of active chains,
  /// adjacency information, and the corresponding cached values
  void mergeChains(Chain *Into,
                   Chain *From,
                   size_t MergeOffset,
                   MergeTypeTy MergeType) {
    assert(Into != From && "a chain cannot be merged with itself");

    // Merge the blocks
    MergedChain MergedBlocks =
        mergeBlocks(Into->blocks(), From->blocks(), MergeOffset, MergeType);
    Into->merge(From, MergedBlocks.getBlocks());
    Into->mergeEdges(From);
    From->clear();

    // Update cached ext-tsp score for the new chain; the score is only
    // recomputed when the merged chain has a self-edge
    Edge *SelfEdge = Into->getEdge(Into);
    if (SelfEdge != nullptr) {
      MergedBlocks = MergedChain(Into->blocks().begin(), Into->blocks().end());
      Into->setScore(score(MergedBlocks, SelfEdge->jumps()));
    }

    // Remove chain From from the list of active chains
    auto Iter = std::remove(HotChains.begin(), HotChains.end(), From);
    HotChains.erase(Iter, HotChains.end());

    // Invalidate caches
    for (std::pair<Chain *, Edge *> EdgeIter : Into->edges()) {
      EdgeIter.second->invalidateCache();
    }
  }

  /// Concatenate all non-empty chains into a final order: the chain holding
  /// the entry block first, then by decreasing density, ties broken by id
  void concatChains(BinaryFunction::BasicBlockOrderType &Order) {
    // Collect chains
    std::vector<Chain *> SortedChains;
    for (Chain &Chain : AllChains) {
      if (Chain.blocks().size() > 0) {
        SortedChains.push_back(&Chain);
      }
    }

    // Sorting chains by density in decreasing order
    std::stable_sort(
      SortedChains.begin(), SortedChains.end(),
      [](const Chain *C1, const Chain *C2) {
        // Original entry point to the front
        if (C1->isEntryPoint() != C2->isEntryPoint()) {
          if (C1->isEntryPoint())
            return true;
          if (C2->isEntryPoint())
            return false;
        }

        const double D1 = C1->density();
        const double D2 = C2->density();
        if (D1 != D2)
          return D1 > D2;

        // Making the order deterministic
        return C1->id() < C2->id();
      }
    );

    // Collect the basic blocks in the order specified by their chains
    Order.reserve(BF.layout_size());
    for (Chain *Chain : SortedChains) {
      for (Block *Block : Chain->blocks()) {
        Order.push_back(Block->BB);
      }
    }
  }

private:
  // The binary function
  const BinaryFunction &BF;

  // All CFG nodes (basic blocks), indexed by layout index
  std::vector<Block> AllBlocks;

  // All chains of blocks
  std::vector<Chain> AllChains;

  // Active chains: initially the chains whose block has a positive execution
  // count; a chain is removed when it is merged into another chain
  std::vector<Chain *> HotChains;

  // All edges between chains
  std::vector<Edge> AllEdges;
};
948 
949 void ExtTSPReorderAlgorithm::reorderBasicBlocks(
950       const BinaryFunction &BF, BasicBlockOrder &Order) const {
951   if (BF.layout_empty())
952     return;
953 
954   // Do not change layout of functions w/o profile information
955   if (!BF.hasValidProfile() || BF.layout_size() <= 2) {
956     for (BinaryBasicBlock *BB : BF.layout()) {
957       Order.push_back(BB);
958     }
959     return;
960   }
961 
962   // Apply the algorithm
963   ExtTSP(BF).run(Order);
964 
965   // Verify correctness
966   assert(Order[0]->isEntryPoint() && "Original entry point is not preserved");
967   assert(Order.size() == BF.layout_size() && "Wrong size of reordered layout");
968 }
969 
970 } // namespace bolt
971 } // namespace llvm
972