1 //===- bolt/Passes/ExtTSPReorderAlgorithm.cpp - Order basic blocks --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // ExtTSP - layout of basic blocks with i-cache optimization.
10 //
11 // The algorithm is a greedy heuristic that works with chains (ordered lists)
12 // of basic blocks. Initially all chains are isolated basic blocks. On every
13 // iteration, we pick a pair of chains whose merging yields the biggest increase
14 // in the ExtTSP value, which models how i-cache "friendly" a specific chain is.
15 // A pair of chains giving the maximum gain is merged into a new chain. The
16 // procedure stops when there is only one chain left, or when merging does not
17 // increase ExtTSP. In the latter case, the remaining chains are sorted by
18 // density in decreasing order.
19 //
// An important aspect is the way two chains are merged. Unlike earlier
// algorithms (e.g., OptimizeCacheReorderAlgorithm or Pettis-Hansen), two
// chains, X and Y, are first split into three, X1, X2, and Y, by cutting X at
// some offset. Then, besides the plain concatenation X1X2Y, we consider the
// gluings X1YX2, YX2X1, and X2X1Y, and choose the one producing the largest
// score. Cuts that would separate a block from its fallthrough successor are
// skipped, and X is only split when it has at most -chain-split-threshold
// blocks.
// This improves the quality of the final result (the search space is larger)
// while keeping the implementation sufficiently fast.
27 //
28 // Reference:
29 //   * A. Newell and S. Pupyrev, Improved Basic Block Reordering,
30 //     IEEE Transactions on Computers, 2020
31 //     https://arxiv.org/abs/1809.04676
32 //
33 //===----------------------------------------------------------------------===//
34 
35 #include "bolt/Core/BinaryBasicBlock.h"
36 #include "bolt/Core/BinaryFunction.h"
37 #include "bolt/Passes/ReorderAlgorithm.h"
38 #include "llvm/Support/CommandLine.h"
39 
40 using namespace llvm;
41 using namespace bolt;
42 
43 namespace opts {
44 
45 extern cl::OptionCategory BoltOptCategory;
46 extern cl::opt<bool> NoThreads;
47 
48 cl::opt<unsigned>
49 ChainSplitThreshold("chain-split-threshold",
50   cl::desc("The maximum size of a chain to apply splitting"),
51   cl::init(128),
52   cl::ReallyHidden,
53   cl::ZeroOrMore,
54   cl::cat(BoltOptCategory));
55 
56 cl::opt<double>
57 ForwardWeight("forward-weight",
58   cl::desc("The weight of forward jumps for ExtTSP value"),
59   cl::init(0.1),
60   cl::ReallyHidden,
61   cl::ZeroOrMore,
62   cl::cat(BoltOptCategory));
63 
64 cl::opt<double>
65 BackwardWeight("backward-weight",
66   cl::desc("The weight of backward jumps for ExtTSP value"),
67   cl::init(0.1),
68   cl::ReallyHidden,
69   cl::ZeroOrMore,
70   cl::cat(BoltOptCategory));
71 
72 cl::opt<unsigned>
73 ForwardDistance("forward-distance",
74   cl::desc("The maximum distance (in bytes) of forward jumps for ExtTSP value"),
75   cl::init(1024),
76   cl::ReallyHidden,
77   cl::ZeroOrMore,
78   cl::cat(BoltOptCategory));
79 
80 cl::opt<unsigned>
81 BackwardDistance("backward-distance",
82   cl::desc("The maximum distance (in bytes) of backward jumps for ExtTSP value"),
83   cl::init(640),
84   cl::ReallyHidden,
85   cl::ZeroOrMore,
86   cl::cat(BoltOptCategory));
87 
88 }
89 
90 namespace llvm {
91 namespace bolt {
92 
// Tolerance used whenever two floating-point scores or gains are compared
constexpr double EPS{1e-8};

// Graph entities of the algorithm; they are defined below
class Edge;
class Chain;
class Block;
99 
100 // Calculate Ext-TSP value, which quantifies the expected number of i-cache
101 // misses for a given ordering of basic blocks
102 double extTSPScore(uint64_t SrcAddr, uint64_t SrcSize, uint64_t DstAddr,
103                    uint64_t Count) {
104   assert(Count != BinaryBasicBlock::COUNT_NO_PROFILE);
105 
106   // Fallthrough
107   if (SrcAddr + SrcSize == DstAddr) {
108     // Assume that FallthroughWeight = 1.0 after normalization
109     return static_cast<double>(Count);
110   }
111   // Forward
112   if (SrcAddr + SrcSize < DstAddr) {
113     const uint64_t Dist = DstAddr - (SrcAddr + SrcSize);
114     if (Dist <= opts::ForwardDistance) {
115       double Prob = 1.0 - static_cast<double>(Dist) / opts::ForwardDistance;
116       return opts::ForwardWeight * Prob * Count;
117     }
118     return 0;
119   }
120   // Backward
121   const uint64_t Dist = SrcAddr + SrcSize - DstAddr;
122   if (Dist <= opts::BackwardDistance) {
123     double Prob = 1.0 - static_cast<double>(Dist) / opts::BackwardDistance;
124     return opts::BackwardWeight * Prob * Count;
125   }
126   return 0;
127 }
128 
129 using BlockPair = std::pair<Block *, Block *>;
130 using JumpList = std::vector<std::pair<BlockPair, uint64_t>>;
131 using BlockIter = std::vector<Block *>::const_iterator;
132 
133 enum MergeTypeTy {
134   X_Y = 0,
135   X1_Y_X2 = 1,
136   Y_X2_X1 = 2,
137   X2_X1_Y = 3,
138 };
139 
140 class MergeGainTy {
141 public:
142   explicit MergeGainTy() {}
143   explicit MergeGainTy(double Score, size_t MergeOffset, MergeTypeTy MergeType)
144       : Score(Score), MergeOffset(MergeOffset), MergeType(MergeType) {}
145 
146   double score() const { return Score; }
147 
148   size_t mergeOffset() const { return MergeOffset; }
149 
150   MergeTypeTy mergeType() const { return MergeType; }
151 
152   // returns 'true' iff Other is preferred over this
153   bool operator<(const MergeGainTy &Other) const {
154     return (Other.Score > EPS && Other.Score > Score + EPS);
155   }
156 
157 private:
158   double Score{-1.0};
159   size_t MergeOffset{0};
160   MergeTypeTy MergeType{MergeTypeTy::X_Y};
161 };
162 
163 // A node in CFG corresponding to a BinaryBasicBlock.
164 // The class wraps several mutable fields utilized in the ExtTSP algorithm
165 class Block {
166 public:
167   Block(const Block &) = delete;
168   Block(Block &&) = default;
169   Block &operator=(const Block &) = delete;
170   Block &operator=(Block &&) = default;
171 
172   // Corresponding basic block
173   BinaryBasicBlock *BB{nullptr};
174   // Current chain of the basic block
175   Chain *CurChain{nullptr};
176   // (Estimated) size of the block in the binary
177   uint64_t Size{0};
178   // Execution count of the block in the binary
179   uint64_t ExecutionCount{0};
180   // An original index of the node in CFG
181   size_t Index{0};
182   // The index of the block in the current chain
183   size_t CurIndex{0};
184   // An offset of the block in the current chain
185   mutable uint64_t EstimatedAddr{0};
186   // Fallthrough successor of the node in CFG
187   Block *FallthroughSucc{nullptr};
188   // Fallthrough predecessor of the node in CFG
189   Block *FallthroughPred{nullptr};
190   // Outgoing jumps from the block
191   std::vector<std::pair<Block *, uint64_t>> OutJumps;
192   // Incoming jumps to the block
193   std::vector<std::pair<Block *, uint64_t>> InJumps;
194   // Total execution count of incoming jumps
195   uint64_t InWeight{0};
196   // Total execution count of outgoing jumps
197   uint64_t OutWeight{0};
198 
199 public:
200   explicit Block(BinaryBasicBlock *BB_, uint64_t Size_)
201       : BB(BB_), Size(Size_), ExecutionCount(BB_->getKnownExecutionCount()),
202         Index(BB->getLayoutIndex()) {}
203 
204   bool adjacent(const Block *Other) const {
205     return hasOutJump(Other) || hasInJump(Other);
206   }
207 
208   bool hasOutJump(const Block *Other) const {
209     for (std::pair<Block *, uint64_t> Jump : OutJumps) {
210       if (Jump.first == Other)
211         return true;
212     }
213     return false;
214   }
215 
216   bool hasInJump(const Block *Other) const {
217     for (std::pair<Block *, uint64_t> Jump : InJumps) {
218       if (Jump.first == Other)
219         return true;
220     }
221     return false;
222   }
223 };
224 
225 // A chain (ordered sequence) of CFG nodes (basic blocks)
226 class Chain {
227 public:
228   Chain(const Chain &) = delete;
229   Chain(Chain &&) = default;
230   Chain &operator=(const Chain &) = delete;
231   Chain &operator=(Chain &&) = default;
232 
233   explicit Chain(size_t Id, Block *Block)
234       : Id(Id), IsEntry(Block->Index == 0),
235         ExecutionCount(Block->ExecutionCount), Size(Block->Size), Score(0),
236         Blocks(1, Block) {}
237 
238   size_t id() const { return Id; }
239 
240   uint64_t size() const { return Size; }
241 
242   double density() const { return static_cast<double>(ExecutionCount) / Size; }
243 
244   uint64_t executionCount() const { return ExecutionCount; }
245 
246   bool isEntryPoint() const { return IsEntry; }
247 
248   double score() const { return Score; }
249 
250   void setScore(double NewScore) { Score = NewScore; }
251 
252   const std::vector<Block *> &blocks() const { return Blocks; }
253 
254   const std::vector<std::pair<Chain *, Edge *>> &edges() const { return Edges; }
255 
256   Edge *getEdge(Chain *Other) const {
257     for (std::pair<Chain *, Edge *> It : Edges)
258       if (It.first == Other)
259         return It.second;
260     return nullptr;
261   }
262 
263   void removeEdge(Chain *Other) {
264     auto It = Edges.begin();
265     while (It != Edges.end()) {
266       if (It->first == Other) {
267         Edges.erase(It);
268         return;
269       }
270       It++;
271     }
272   }
273 
274   void addEdge(Chain *Other, Edge *Edge) { Edges.emplace_back(Other, Edge); }
275 
276   void merge(Chain *Other, const std::vector<Block *> &MergedBlocks) {
277     Blocks = MergedBlocks;
278     IsEntry |= Other->IsEntry;
279     ExecutionCount += Other->ExecutionCount;
280     Size += Other->Size;
281     // Update block's chains
282     for (size_t Idx = 0; Idx < Blocks.size(); Idx++) {
283       Blocks[Idx]->CurChain = this;
284       Blocks[Idx]->CurIndex = Idx;
285     }
286   }
287 
288   void mergeEdges(Chain *Other);
289 
290   void clear() {
291     Blocks.clear();
292     Edges.clear();
293   }
294 
295 private:
296   size_t Id;
297   bool IsEntry;
298   uint64_t ExecutionCount;
299   uint64_t Size;
300   // Cached ext-tsp score for the chain
301   double Score;
302   // Blocks of the chain
303   std::vector<Block *> Blocks;
304   // Adjacent chains and corresponding edges (lists of jumps)
305   std::vector<std::pair<Chain *, Edge *>> Edges;
306 };
307 
308 // An edge in CFG reprsenting jumps between chains of BinaryBasicBlocks.
309 // When blocks are merged into chains, the edges are combined too so that
310 // there is always at most one edge between a pair of chains
311 class Edge {
312 public:
313   Edge(const Edge &) = delete;
314   Edge(Edge &&) = default;
315   Edge &operator=(const Edge &) = delete;
316   Edge &operator=(Edge &&) = default;
317 
318   explicit Edge(Block *SrcBlock, Block *DstBlock, uint64_t EC)
319       : SrcChain(SrcBlock->CurChain), DstChain(DstBlock->CurChain),
320         Jumps(1, std::make_pair(std::make_pair(SrcBlock, DstBlock), EC)) {}
321 
322   const JumpList &jumps() const { return Jumps; }
323 
324   void changeEndpoint(Chain *From, Chain *To) {
325     if (From == SrcChain)
326       SrcChain = To;
327     if (From == DstChain)
328       DstChain = To;
329   }
330 
331   void appendJump(Block *SrcBlock, Block *DstBlock, uint64_t EC) {
332     Jumps.emplace_back(std::make_pair(SrcBlock, DstBlock), EC);
333   }
334 
335   void moveJumps(Edge *Other) {
336     Jumps.insert(Jumps.end(), Other->Jumps.begin(), Other->Jumps.end());
337     Other->Jumps.clear();
338   }
339 
340   bool hasCachedMergeGain(Chain *Src, Chain *Dst) const {
341     return Src == SrcChain ? CacheValidForward : CacheValidBackward;
342   }
343 
344   MergeGainTy getCachedMergeGain(Chain *Src, Chain *Dst) const {
345     return Src == SrcChain ? CachedGainForward : CachedGainBackward;
346   }
347 
348   void setCachedMergeGain(Chain *Src, Chain *Dst, MergeGainTy MergeGain) {
349     if (Src == SrcChain) {
350       CachedGainForward = MergeGain;
351       CacheValidForward = true;
352     } else {
353       CachedGainBackward = MergeGain;
354       CacheValidBackward = true;
355     }
356   }
357 
358   void invalidateCache() {
359     CacheValidForward = false;
360     CacheValidBackward = false;
361   }
362 
363 private:
364   Chain *SrcChain{nullptr};
365   Chain *DstChain{nullptr};
366   // Original jumps in the binary with correspinding execution counts
367   JumpList Jumps;
368   // Cached ext-tsp value for merging the pair of chains
369   // Since the gain of merging (Src, Dst) and (Dst, Src) might be different,
370   // we store both values here
371   MergeGainTy CachedGainForward;
372   MergeGainTy CachedGainBackward;
373   // Whether the cached value must be recomputed
374   bool CacheValidForward{false};
375   bool CacheValidBackward{false};
376 };
377 
378 void Chain::mergeEdges(Chain *Other) {
379   assert(this != Other && "cannot merge a chain with itself");
380 
381   // Update edges adjacent to chain Other
382   for (auto EdgeIt : Other->Edges) {
383     Chain *const DstChain = EdgeIt.first;
384     Edge *const DstEdge = EdgeIt.second;
385     Chain *const TargetChain = DstChain == Other ? this : DstChain;
386 
387     // Find the corresponding edge in the current chain
388     Edge *curEdge = getEdge(TargetChain);
389     if (curEdge == nullptr) {
390       DstEdge->changeEndpoint(Other, this);
391       this->addEdge(TargetChain, DstEdge);
392       if (DstChain != this && DstChain != Other)
393         DstChain->addEdge(this, DstEdge);
394     } else {
395       curEdge->moveJumps(DstEdge);
396     }
397     // Cleanup leftover edge
398     if (DstChain != Other)
399       DstChain->removeEdge(Other);
400   }
401 }
402 
403 // A wrapper around three chains of basic blocks; it is used to avoid extra
404 // instantiation of the vectors.
405 class MergedChain {
406 public:
407   MergedChain(BlockIter Begin1, BlockIter End1, BlockIter Begin2 = BlockIter(),
408               BlockIter End2 = BlockIter(), BlockIter Begin3 = BlockIter(),
409               BlockIter End3 = BlockIter())
410       : Begin1(Begin1), End1(End1), Begin2(Begin2), End2(End2), Begin3(Begin3),
411         End3(End3) {}
412 
413   template <typename F> void forEach(const F &Func) const {
414     for (auto It = Begin1; It != End1; It++)
415       Func(*It);
416     for (auto It = Begin2; It != End2; It++)
417       Func(*It);
418     for (auto It = Begin3; It != End3; It++)
419       Func(*It);
420   }
421 
422   std::vector<Block *> getBlocks() const {
423     std::vector<Block *> Result;
424     Result.reserve(std::distance(Begin1, End1) + std::distance(Begin2, End2) +
425                    std::distance(Begin3, End3));
426     Result.insert(Result.end(), Begin1, End1);
427     Result.insert(Result.end(), Begin2, End2);
428     Result.insert(Result.end(), Begin3, End3);
429     return Result;
430   }
431 
432   const Block *getFirstBlock() const { return *Begin1; }
433 
434 private:
435   BlockIter Begin1;
436   BlockIter End1;
437   BlockIter Begin2;
438   BlockIter End2;
439   BlockIter Begin3;
440   BlockIter End3;
441 };
442 
443 /// Deterministically compare pairs of chains
444 bool compareChainPairs(const Chain *A1, const Chain *B1, const Chain *A2,
445                        const Chain *B2) {
446   const uint64_t Samples1 = A1->executionCount() + B1->executionCount();
447   const uint64_t Samples2 = A2->executionCount() + B2->executionCount();
448   if (Samples1 != Samples2)
449     return Samples1 < Samples2;
450 
451   // Making the order deterministic
452   if (A1 != A2)
453     return A1->id() < A2->id();
454   return B1->id() < B2->id();
455 }
456 class ExtTSP {
457 public:
458   ExtTSP(const BinaryFunction &BF) : BF(BF) { initialize(); }
459 
460   /// Run the algorithm and return an ordering of basic block
461   void run(BinaryFunction::BasicBlockOrderType &Order) {
462     // Pass 1: Merge blocks with their fallthrough successors
463     mergeFallthroughs();
464 
465     // Pass 2: Merge pairs of chains while improving the ExtTSP objective
466     mergeChainPairs();
467 
468     // Pass 3: Merge cold blocks to reduce code size
469     mergeColdChains();
470 
471     // Collect blocks from all chains
472     concatChains(Order);
473   }
474 
475 private:
476   /// Initialize algorithm's data structures
477   void initialize() {
478     // Create a separate MCCodeEmitter to allow lock-free execution
479     BinaryContext::IndependentCodeEmitter Emitter;
480     if (!opts::NoThreads)
481       Emitter = BF.getBinaryContext().createIndependentMCCodeEmitter();
482 
483     // Initialize CFG nodes
484     AllBlocks.reserve(BF.layout_size());
485     size_t LayoutIndex = 0;
486     for (BinaryBasicBlock *BB : BF.layout()) {
487       BB->setLayoutIndex(LayoutIndex++);
488       uint64_t Size =
489           std::max<uint64_t>(BB->estimateSize(Emitter.MCE.get()), 1);
490       AllBlocks.emplace_back(BB, Size);
491     }
492 
493     // Initialize edges for the blocks and compute their total in/out weights
494     size_t NumEdges = 0;
495     for (Block &Block : AllBlocks) {
496       auto BI = Block.BB->branch_info_begin();
497       for (BinaryBasicBlock *SuccBB : Block.BB->successors()) {
498         assert(BI->Count != BinaryBasicBlock::COUNT_NO_PROFILE &&
499                "missing profile for a jump");
500         if (SuccBB != Block.BB && BI->Count > 0) {
501           class Block &SuccBlock = AllBlocks[SuccBB->getLayoutIndex()];
502           uint64_t Count = BI->Count;
503           SuccBlock.InWeight += Count;
504           SuccBlock.InJumps.emplace_back(&Block, Count);
505           Block.OutWeight += Count;
506           Block.OutJumps.emplace_back(&SuccBlock, Count);
507           NumEdges++;
508         }
509         ++BI;
510       }
511     }
512 
513     // Initialize execution count for every basic block, which is the
514     // maximum over the sums of all in and out edge weights.
515     // Also execution count of the entry point is set to at least 1
516     for (Block &Block : AllBlocks) {
517       size_t Index = Block.Index;
518       Block.ExecutionCount = std::max(Block.ExecutionCount, Block.InWeight);
519       Block.ExecutionCount = std::max(Block.ExecutionCount, Block.OutWeight);
520       if (Index == 0 && Block.ExecutionCount == 0)
521         Block.ExecutionCount = 1;
522     }
523 
524     // Initialize chains
525     AllChains.reserve(BF.layout_size());
526     HotChains.reserve(BF.layout_size());
527     for (Block &Block : AllBlocks) {
528       AllChains.emplace_back(Block.Index, &Block);
529       Block.CurChain = &AllChains.back();
530       if (Block.ExecutionCount > 0)
531         HotChains.push_back(&AllChains.back());
532     }
533 
534     // Initialize edges
535     AllEdges.reserve(NumEdges);
536     for (Block &Block : AllBlocks) {
537       for (std::pair<class Block *, uint64_t> &Jump : Block.OutJumps) {
538         class Block *const SuccBlock = Jump.first;
539         Edge *CurEdge = Block.CurChain->getEdge(SuccBlock->CurChain);
540         // this edge is already present in the graph
541         if (CurEdge != nullptr) {
542           assert(SuccBlock->CurChain->getEdge(Block.CurChain) != nullptr);
543           CurEdge->appendJump(&Block, SuccBlock, Jump.second);
544           continue;
545         }
546         // this is a new edge
547         AllEdges.emplace_back(&Block, SuccBlock, Jump.second);
548         Block.CurChain->addEdge(SuccBlock->CurChain, &AllEdges.back());
549         SuccBlock->CurChain->addEdge(Block.CurChain, &AllEdges.back());
550       }
551     }
552     assert(AllEdges.size() <= NumEdges && "Incorrect number of created edges");
553   }
554 
555   /// For a pair of blocks, A and B, block B is the fallthrough successor of A,
556   /// if (i) all jumps (based on profile) from A goes to B and (ii) all jumps
557   /// to B are from A. Such blocks should be adjacent in an optimal ordering;
558   /// the method finds and merges such pairs of blocks
559   void mergeFallthroughs() {
560     // Find fallthroughs based on edge weights
561     for (Block &Block : AllBlocks) {
562       if (Block.BB->succ_size() == 1 &&
563           Block.BB->getSuccessor()->pred_size() == 1 &&
564           Block.BB->getSuccessor()->getLayoutIndex() != 0) {
565         size_t SuccIndex = Block.BB->getSuccessor()->getLayoutIndex();
566         Block.FallthroughSucc = &AllBlocks[SuccIndex];
567         AllBlocks[SuccIndex].FallthroughPred = &Block;
568         continue;
569       }
570 
571       if (Block.OutWeight == 0)
572         continue;
573       for (std::pair<class Block *, uint64_t> &Edge : Block.OutJumps) {
574         class Block *const SuccBlock = Edge.first;
575         // Successor cannot be the first BB, which is pinned
576         if (Block.OutWeight == Edge.second &&
577             SuccBlock->InWeight == Edge.second && SuccBlock->Index != 0) {
578           Block.FallthroughSucc = SuccBlock;
579           SuccBlock->FallthroughPred = &Block;
580           break;
581         }
582       }
583     }
584 
585     // There might be 'cycles' in the fallthrough dependencies (since profile
586     // data isn't 100% accurate).
587     // Break the cycles by choosing the block with smallest index as the tail
588     for (Block &Block : AllBlocks) {
589       if (Block.FallthroughSucc == nullptr || Block.FallthroughPred == nullptr)
590         continue;
591 
592       class Block *SuccBlock = Block.FallthroughSucc;
593       while (SuccBlock != nullptr && SuccBlock != &Block)
594         SuccBlock = SuccBlock->FallthroughSucc;
595 
596       if (SuccBlock == nullptr)
597         continue;
598       // break the cycle
599       AllBlocks[Block.FallthroughPred->Index].FallthroughSucc = nullptr;
600       Block.FallthroughPred = nullptr;
601     }
602 
603     // Merge blocks with their fallthrough successors
604     for (Block &Block : AllBlocks) {
605       if (Block.FallthroughPred == nullptr &&
606           Block.FallthroughSucc != nullptr) {
607         class Block *CurBlock = &Block;
608         while (CurBlock->FallthroughSucc != nullptr) {
609           class Block *const NextBlock = CurBlock->FallthroughSucc;
610           mergeChains(Block.CurChain, NextBlock->CurChain, 0, MergeTypeTy::X_Y);
611           CurBlock = NextBlock;
612         }
613       }
614     }
615   }
616 
617   /// Merge pairs of chains while improving the ExtTSP objective
618   void mergeChainPairs() {
619     while (HotChains.size() > 1) {
620       Chain *BestChainPred = nullptr;
621       Chain *BestChainSucc = nullptr;
622       auto BestGain = MergeGainTy();
623       // Iterate over all pairs of chains
624       for (Chain *ChainPred : HotChains) {
625         // Get candidates for merging with the current chain
626         for (auto EdgeIter : ChainPred->edges()) {
627           Chain *ChainSucc = EdgeIter.first;
628           Edge *ChainEdge = EdgeIter.second;
629           // Ignore loop edges
630           if (ChainPred == ChainSucc)
631             continue;
632 
633           // Compute the gain of merging the two chains
634           MergeGainTy CurGain = mergeGain(ChainPred, ChainSucc, ChainEdge);
635           if (CurGain.score() <= EPS)
636             continue;
637 
638           if (BestGain < CurGain ||
639               (std::abs(CurGain.score() - BestGain.score()) < EPS &&
640                compareChainPairs(ChainPred, ChainSucc, BestChainPred,
641                                  BestChainSucc))) {
642             BestGain = CurGain;
643             BestChainPred = ChainPred;
644             BestChainSucc = ChainSucc;
645           }
646         }
647       }
648 
649       // Stop merging when there is no improvement
650       if (BestGain.score() <= EPS)
651         break;
652 
653       // Merge the best pair of chains
654       mergeChains(BestChainPred, BestChainSucc, BestGain.mergeOffset(),
655                   BestGain.mergeType());
656     }
657   }
658 
659   /// Merge cold blocks to reduce code size
660   void mergeColdChains() {
661     for (BinaryBasicBlock *SrcBB : BF.layout()) {
662       // Iterating in reverse order to make sure original fallthrough jumps are
663       // merged first
664       for (auto Itr = SrcBB->succ_rbegin(); Itr != SrcBB->succ_rend(); ++Itr) {
665         BinaryBasicBlock *DstBB = *Itr;
666         size_t SrcIndex = SrcBB->getLayoutIndex();
667         size_t DstIndex = DstBB->getLayoutIndex();
668         Chain *SrcChain = AllBlocks[SrcIndex].CurChain;
669         Chain *DstChain = AllBlocks[DstIndex].CurChain;
670         if (SrcChain != DstChain && !DstChain->isEntryPoint() &&
671             SrcChain->blocks().back()->Index == SrcIndex &&
672             DstChain->blocks().front()->Index == DstIndex)
673           mergeChains(SrcChain, DstChain, 0, MergeTypeTy::X_Y);
674       }
675     }
676   }
677 
678   /// Compute ExtTSP score for a given order of basic blocks
679   double score(const MergedChain &MergedBlocks, const JumpList &Jumps) const {
680     if (Jumps.empty())
681       return 0.0;
682     uint64_t CurAddr = 0;
683     MergedBlocks.forEach(
684       [&](const Block *BB) {
685         BB->EstimatedAddr = CurAddr;
686         CurAddr += BB->Size;
687       }
688     );
689 
690     double Score = 0;
691     for (const std::pair<std::pair<Block *, Block *>, uint64_t> &Jump : Jumps) {
692       const Block *SrcBlock = Jump.first.first;
693       const Block *DstBlock = Jump.first.second;
694       Score += extTSPScore(SrcBlock->EstimatedAddr, SrcBlock->Size,
695                            DstBlock->EstimatedAddr, Jump.second);
696     }
697     return Score;
698   }
699 
700   /// Compute the gain of merging two chains
701   ///
702   /// The function considers all possible ways of merging two chains and
703   /// computes the one having the largest increase in ExtTSP objective. The
704   /// result is a pair with the first element being the gain and the second
705   /// element being the corresponding merging type.
706   MergeGainTy mergeGain(Chain *ChainPred, Chain *ChainSucc, Edge *Edge) const {
707     if (Edge->hasCachedMergeGain(ChainPred, ChainSucc))
708       return Edge->getCachedMergeGain(ChainPred, ChainSucc);
709 
710     // Precompute jumps between ChainPred and ChainSucc
711     JumpList Jumps = Edge->jumps();
712     class Edge *EdgePP = ChainPred->getEdge(ChainPred);
713     if (EdgePP != nullptr)
714       Jumps.insert(Jumps.end(), EdgePP->jumps().begin(), EdgePP->jumps().end());
715     assert(Jumps.size() > 0 && "trying to merge chains w/o jumps");
716 
717     MergeGainTy Gain = MergeGainTy();
718     // Try to concatenate two chains w/o splitting
719     Gain = computeMergeGain(Gain, ChainPred, ChainSucc, Jumps, 0,
720                             MergeTypeTy::X_Y);
721 
722     // Try to break ChainPred in various ways and concatenate with ChainSucc
723     if (ChainPred->blocks().size() <= opts::ChainSplitThreshold) {
724       for (size_t Offset = 1; Offset < ChainPred->blocks().size(); Offset++) {
725         Block *BB1 = ChainPred->blocks()[Offset - 1];
726         Block *BB2 = ChainPred->blocks()[Offset];
727         // Does the splitting break FT successors?
728         if (BB1->FallthroughSucc != nullptr) {
729           (void)BB2;
730           assert(BB1->FallthroughSucc == BB2 && "Fallthrough not preserved");
731           continue;
732         }
733 
734         Gain = computeMergeGain(Gain, ChainPred, ChainSucc, Jumps, Offset,
735                                 MergeTypeTy::X1_Y_X2);
736         Gain = computeMergeGain(Gain, ChainPred, ChainSucc, Jumps, Offset,
737                                 MergeTypeTy::Y_X2_X1);
738         Gain = computeMergeGain(Gain, ChainPred, ChainSucc, Jumps, Offset,
739                                 MergeTypeTy::X2_X1_Y);
740       }
741     }
742 
743     Edge->setCachedMergeGain(ChainPred, ChainSucc, Gain);
744     return Gain;
745   }
746 
747   /// Merge two chains and update the best Gain
748   MergeGainTy computeMergeGain(const MergeGainTy &CurGain,
749                                const Chain *ChainPred, const Chain *ChainSucc,
750                                const JumpList &Jumps, size_t MergeOffset,
751                                MergeTypeTy MergeType) const {
752     MergedChain MergedBlocks = mergeBlocks(
753         ChainPred->blocks(), ChainSucc->blocks(), MergeOffset, MergeType);
754 
755     // Do not allow a merge that does not preserve the original entry block
756     if ((ChainPred->isEntryPoint() || ChainSucc->isEntryPoint()) &&
757         MergedBlocks.getFirstBlock()->Index != 0)
758       return CurGain;
759 
760     // The gain for the new chain
761     const double NewScore = score(MergedBlocks, Jumps) - ChainPred->score();
762     auto NewGain = MergeGainTy(NewScore, MergeOffset, MergeType);
763     return CurGain < NewGain ? NewGain : CurGain;
764   }
765 
766   /// Merge two chains of blocks respecting a given merge 'type' and 'offset'
767   ///
768   /// If MergeType == 0, then the result is a concatentation of two chains.
769   /// Otherwise, the first chain is cut into two sub-chains at the offset,
770   /// and merged using all possible ways of concatenating three chains.
771   MergedChain mergeBlocks(const std::vector<Block *> &X,
772                           const std::vector<Block *> &Y, size_t MergeOffset,
773                           MergeTypeTy MergeType) const {
774     // Split the first chain, X, into X1 and X2
775     BlockIter BeginX1 = X.begin();
776     BlockIter EndX1 = X.begin() + MergeOffset;
777     BlockIter BeginX2 = X.begin() + MergeOffset;
778     BlockIter EndX2 = X.end();
779     BlockIter BeginY = Y.begin();
780     BlockIter EndY = Y.end();
781 
782     // Construct a new chain from the three existing ones
783     switch (MergeType) {
784     case MergeTypeTy::X_Y:
785       return MergedChain(BeginX1, EndX2, BeginY, EndY);
786     case MergeTypeTy::X1_Y_X2:
787       return MergedChain(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2);
788     case MergeTypeTy::Y_X2_X1:
789       return MergedChain(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1);
790     case MergeTypeTy::X2_X1_Y:
791       return MergedChain(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY);
792     }
793 
794     llvm_unreachable("unexpected merge type");
795   }
796 
797   /// Merge chain From into chain Into, update the list of active chains,
798   /// adjacency information, and the corresponding cached values
799   void mergeChains(Chain *Into, Chain *From, size_t MergeOffset,
800                    MergeTypeTy MergeType) {
801     assert(Into != From && "a chain cannot be merged with itself");
802 
803     // Merge the blocks
804     MergedChain MergedBlocks =
805         mergeBlocks(Into->blocks(), From->blocks(), MergeOffset, MergeType);
806     Into->merge(From, MergedBlocks.getBlocks());
807     Into->mergeEdges(From);
808     From->clear();
809 
810     // Update cached ext-tsp score for the new chain
811     Edge *SelfEdge = Into->getEdge(Into);
812     if (SelfEdge != nullptr) {
813       MergedBlocks = MergedChain(Into->blocks().begin(), Into->blocks().end());
814       Into->setScore(score(MergedBlocks, SelfEdge->jumps()));
815     }
816 
817     // Remove chain From from the list of active chains
818     auto Iter = std::remove(HotChains.begin(), HotChains.end(), From);
819     HotChains.erase(Iter, HotChains.end());
820 
821     // Invalidate caches
822     for (std::pair<Chain *, Edge *> EdgeIter : Into->edges())
823       EdgeIter.second->invalidateCache();
824   }
825 
826   /// Concatenate all chains into a final order
827   void concatChains(BinaryFunction::BasicBlockOrderType &Order) {
828     // Collect chains
829     std::vector<Chain *> SortedChains;
830     for (Chain &Chain : AllChains)
831       if (Chain.blocks().size() > 0)
832         SortedChains.push_back(&Chain);
833 
834     // Sorting chains by density in decreasing order
835     std::stable_sort(
836       SortedChains.begin(), SortedChains.end(),
837       [](const Chain *C1, const Chain *C2) {
838         // Original entry point to the front
839         if (C1->isEntryPoint() != C2->isEntryPoint()) {
840           if (C1->isEntryPoint())
841             return true;
842           if (C2->isEntryPoint())
843             return false;
844         }
845 
846         const double D1 = C1->density();
847         const double D2 = C2->density();
848         if (D1 != D2)
849           return D1 > D2;
850 
851         // Making the order deterministic
852         return C1->id() < C2->id();
853       }
854     );
855 
856     // Collect the basic blocks in the order specified by their chains
857     Order.reserve(BF.layout_size());
858     for (Chain *Chain : SortedChains)
859       for (Block *Block : Chain->blocks())
860         Order.push_back(Block->BB);
861   }
862 
863 private:
864   // The binary function
865   const BinaryFunction &BF;
866 
867   // All CFG nodes (basic blocks)
868   std::vector<Block> AllBlocks;
869 
870   // All chains of blocks
871   std::vector<Chain> AllChains;
872 
873   // Active chains. The vector gets updated at runtime when chains are merged
874   std::vector<Chain *> HotChains;
875 
876   // All edges between chains
877   std::vector<Edge> AllEdges;
878 };
879 
880 void ExtTSPReorderAlgorithm::reorderBasicBlocks(const BinaryFunction &BF,
881                                                 BasicBlockOrder &Order) const {
882   if (BF.layout_empty())
883     return;
884 
885   // Do not change layout of functions w/o profile information
886   if (!BF.hasValidProfile() || BF.layout_size() <= 2) {
887     for (BinaryBasicBlock *BB : BF.layout())
888       Order.push_back(BB);
889     return;
890   }
891 
892   // Apply the algorithm
893   ExtTSP(BF).run(Order);
894 
895   // Verify correctness
896   assert(Order[0]->isEntryPoint() && "Original entry point is not preserved");
897   assert(Order.size() == BF.layout_size() && "Wrong size of reordered layout");
898 }
899 
900 } // namespace bolt
901 } // namespace llvm
902