//===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains switch inst lowering optimizations and utilities for
// codegen, so that it can be used for both SelectionDAG and GlobalISel.
//
//===----------------------------------------------------------------------===//
13829037a9SAmara Emerson 
14829037a9SAmara Emerson #include "llvm/CodeGen/SwitchLoweringUtils.h"
151673a080SSimon Pilgrim #include "llvm/CodeGen/FunctionLoweringInfo.h"
161673a080SSimon Pilgrim #include "llvm/CodeGen/MachineJumpTableInfo.h"
17f42f733aSSimon Pilgrim #include "llvm/CodeGen/TargetLowering.h"
18fe0006c8SSimon Pilgrim #include "llvm/Target/TargetMachine.h"
19829037a9SAmara Emerson 
20829037a9SAmara Emerson using namespace llvm;
21829037a9SAmara Emerson using namespace SwitchCG;
22829037a9SAmara Emerson 
getJumpTableRange(const CaseClusterVector & Clusters,unsigned First,unsigned Last)233740ae3bSHans Wennborg uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
243740ae3bSHans Wennborg                                      unsigned First, unsigned Last) {
253740ae3bSHans Wennborg   assert(Last >= First);
263740ae3bSHans Wennborg   const APInt &LowCase = Clusters[First].Low->getValue();
273740ae3bSHans Wennborg   const APInt &HighCase = Clusters[Last].High->getValue();
283740ae3bSHans Wennborg   assert(LowCase.getBitWidth() == HighCase.getBitWidth());
29829037a9SAmara Emerson 
303740ae3bSHans Wennborg   // FIXME: A range of consecutive cases has 100% density, but only requires one
313740ae3bSHans Wennborg   // comparison to lower. We should discriminate against such consecutive ranges
323740ae3bSHans Wennborg   // in jump tables.
333740ae3bSHans Wennborg   return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
34829037a9SAmara Emerson }
35829037a9SAmara Emerson 
363740ae3bSHans Wennborg uint64_t
getJumpTableNumCases(const SmallVectorImpl<unsigned> & TotalCases,unsigned First,unsigned Last)373740ae3bSHans Wennborg SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
383740ae3bSHans Wennborg                                unsigned First, unsigned Last) {
393740ae3bSHans Wennborg   assert(Last >= First);
403740ae3bSHans Wennborg   assert(TotalCases[Last] >= TotalCases[First]);
413740ae3bSHans Wennborg   uint64_t NumCases =
423740ae3bSHans Wennborg       TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
433740ae3bSHans Wennborg   return NumCases;
44829037a9SAmara Emerson }
45829037a9SAmara Emerson 
findJumpTables(CaseClusterVector & Clusters,const SwitchInst * SI,MachineBasicBlock * DefaultMBB,ProfileSummaryInfo * PSI,BlockFrequencyInfo * BFI)46829037a9SAmara Emerson void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
47829037a9SAmara Emerson                                               const SwitchInst *SI,
480d987e41SHiroshi Yamauchi                                               MachineBasicBlock *DefaultMBB,
490d987e41SHiroshi Yamauchi                                               ProfileSummaryInfo *PSI,
500d987e41SHiroshi Yamauchi                                               BlockFrequencyInfo *BFI) {
51829037a9SAmara Emerson #ifndef NDEBUG
52829037a9SAmara Emerson   // Clusters must be non-empty, sorted, and only contain Range clusters.
53829037a9SAmara Emerson   assert(!Clusters.empty());
54829037a9SAmara Emerson   for (CaseCluster &C : Clusters)
55829037a9SAmara Emerson     assert(C.Kind == CC_Range);
56829037a9SAmara Emerson   for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
57829037a9SAmara Emerson     assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
58829037a9SAmara Emerson #endif
59829037a9SAmara Emerson 
60fe4625fbSAmara Emerson   assert(TLI && "TLI not set!");
61829037a9SAmara Emerson   if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
62829037a9SAmara Emerson     return;
63829037a9SAmara Emerson 
64829037a9SAmara Emerson   const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
65829037a9SAmara Emerson   const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;
66829037a9SAmara Emerson 
67aa10f050SEvandro Menezes   // Bail if not enough cases.
68aa10f050SEvandro Menezes   const int64_t N = Clusters.size();
69829037a9SAmara Emerson   if (N < 2 || N < MinJumpTableEntries)
70829037a9SAmara Emerson     return;
71829037a9SAmara Emerson 
723740ae3bSHans Wennborg   // Accumulated number of cases in each cluster and those prior to it.
733740ae3bSHans Wennborg   SmallVector<unsigned, 8> TotalCases(N);
743740ae3bSHans Wennborg   for (unsigned i = 0; i < N; ++i) {
753740ae3bSHans Wennborg     const APInt &Hi = Clusters[i].High->getValue();
763740ae3bSHans Wennborg     const APInt &Lo = Clusters[i].Low->getValue();
773740ae3bSHans Wennborg     TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
783740ae3bSHans Wennborg     if (i != 0)
793740ae3bSHans Wennborg       TotalCases[i] += TotalCases[i - 1];
803740ae3bSHans Wennborg   }
813740ae3bSHans Wennborg 
823740ae3bSHans Wennborg   uint64_t Range = getJumpTableRange(Clusters,0, N - 1);
833740ae3bSHans Wennborg   uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
843740ae3bSHans Wennborg   assert(NumCases < UINT64_MAX / 100);
853740ae3bSHans Wennborg   assert(Range >= NumCases);
86aa10f050SEvandro Menezes 
87aa10f050SEvandro Menezes   // Cheap case: the whole range may be suitable for jump table.
880d987e41SHiroshi Yamauchi   if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
89829037a9SAmara Emerson     CaseCluster JTCluster;
90829037a9SAmara Emerson     if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
91829037a9SAmara Emerson       Clusters[0] = JTCluster;
92829037a9SAmara Emerson       Clusters.resize(1);
93829037a9SAmara Emerson       return;
94829037a9SAmara Emerson     }
95829037a9SAmara Emerson   }
96829037a9SAmara Emerson 
97829037a9SAmara Emerson   // The algorithm below is not suitable for -O0.
98829037a9SAmara Emerson   if (TM->getOptLevel() == CodeGenOpt::None)
99829037a9SAmara Emerson     return;
100829037a9SAmara Emerson 
101829037a9SAmara Emerson   // Split Clusters into minimum number of dense partitions. The algorithm uses
102829037a9SAmara Emerson   // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
103829037a9SAmara Emerson   // for the Case Statement'" (1994), but builds the MinPartitions array in
104829037a9SAmara Emerson   // reverse order to make it easier to reconstruct the partitions in ascending
105829037a9SAmara Emerson   // order. In the choice between two optimal partitionings, it picks the one
106829037a9SAmara Emerson   // which yields more jump tables.
107829037a9SAmara Emerson 
108829037a9SAmara Emerson   // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
109829037a9SAmara Emerson   SmallVector<unsigned, 8> MinPartitions(N);
110829037a9SAmara Emerson   // LastElement[i] is the last element of the partition starting at i.
111829037a9SAmara Emerson   SmallVector<unsigned, 8> LastElement(N);
1123740ae3bSHans Wennborg   // PartitionsScore[i] is used to break ties when choosing between two
1133740ae3bSHans Wennborg   // partitionings resulting in the same number of partitions.
1143740ae3bSHans Wennborg   SmallVector<unsigned, 8> PartitionsScore(N);
115829037a9SAmara Emerson   // For PartitionsScore, a small number of comparisons is considered as good as
116829037a9SAmara Emerson   // a jump table and a single comparison is considered better than a jump
117829037a9SAmara Emerson   // table.
118829037a9SAmara Emerson   enum PartitionScores : unsigned {
119829037a9SAmara Emerson     NoTable = 0,
120829037a9SAmara Emerson     Table = 1,
121829037a9SAmara Emerson     FewCases = 1,
122829037a9SAmara Emerson     SingleCase = 2
123829037a9SAmara Emerson   };
124829037a9SAmara Emerson 
125829037a9SAmara Emerson   // Base case: There is only one way to partition Clusters[N-1].
126829037a9SAmara Emerson   MinPartitions[N - 1] = 1;
127829037a9SAmara Emerson   LastElement[N - 1] = N - 1;
128829037a9SAmara Emerson   PartitionsScore[N - 1] = PartitionScores::SingleCase;
129829037a9SAmara Emerson 
130829037a9SAmara Emerson   // Note: loop indexes are signed to avoid underflow.
131829037a9SAmara Emerson   for (int64_t i = N - 2; i >= 0; i--) {
132829037a9SAmara Emerson     // Find optimal partitioning of Clusters[i..N-1].
133829037a9SAmara Emerson     // Baseline: Put Clusters[i] into a partition on its own.
134829037a9SAmara Emerson     MinPartitions[i] = MinPartitions[i + 1] + 1;
135829037a9SAmara Emerson     LastElement[i] = i;
136829037a9SAmara Emerson     PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;
137829037a9SAmara Emerson 
138829037a9SAmara Emerson     // Search for a solution that results in fewer partitions.
139829037a9SAmara Emerson     for (int64_t j = N - 1; j > i; j--) {
140829037a9SAmara Emerson       // Try building a partition from Clusters[i..j].
1413740ae3bSHans Wennborg       Range = getJumpTableRange(Clusters, i, j);
1423740ae3bSHans Wennborg       NumCases = getJumpTableNumCases(TotalCases, i, j);
1433740ae3bSHans Wennborg       assert(NumCases < UINT64_MAX / 100);
1443740ae3bSHans Wennborg       assert(Range >= NumCases);
1453740ae3bSHans Wennborg 
1460d987e41SHiroshi Yamauchi       if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
147829037a9SAmara Emerson         unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
148829037a9SAmara Emerson         unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
149829037a9SAmara Emerson         int64_t NumEntries = j - i + 1;
150829037a9SAmara Emerson 
151829037a9SAmara Emerson         if (NumEntries == 1)
152829037a9SAmara Emerson           Score += PartitionScores::SingleCase;
153829037a9SAmara Emerson         else if (NumEntries <= SmallNumberOfEntries)
154829037a9SAmara Emerson           Score += PartitionScores::FewCases;
155829037a9SAmara Emerson         else if (NumEntries >= MinJumpTableEntries)
156829037a9SAmara Emerson           Score += PartitionScores::Table;
157829037a9SAmara Emerson 
158829037a9SAmara Emerson         // If this leads to fewer partitions, or to the same number of
159829037a9SAmara Emerson         // partitions with better score, it is a better partitioning.
160829037a9SAmara Emerson         if (NumPartitions < MinPartitions[i] ||
161829037a9SAmara Emerson             (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
162829037a9SAmara Emerson           MinPartitions[i] = NumPartitions;
163829037a9SAmara Emerson           LastElement[i] = j;
164829037a9SAmara Emerson           PartitionsScore[i] = Score;
165829037a9SAmara Emerson         }
166829037a9SAmara Emerson       }
167829037a9SAmara Emerson     }
168829037a9SAmara Emerson   }
169829037a9SAmara Emerson 
170829037a9SAmara Emerson   // Iterate over the partitions, replacing some with jump tables in-place.
171829037a9SAmara Emerson   unsigned DstIndex = 0;
172829037a9SAmara Emerson   for (unsigned First = 0, Last; First < N; First = Last + 1) {
173829037a9SAmara Emerson     Last = LastElement[First];
174829037a9SAmara Emerson     assert(Last >= First);
175829037a9SAmara Emerson     assert(DstIndex <= First);
176829037a9SAmara Emerson     unsigned NumClusters = Last - First + 1;
177829037a9SAmara Emerson 
178829037a9SAmara Emerson     CaseCluster JTCluster;
179829037a9SAmara Emerson     if (NumClusters >= MinJumpTableEntries &&
180829037a9SAmara Emerson         buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
181829037a9SAmara Emerson       Clusters[DstIndex++] = JTCluster;
182829037a9SAmara Emerson     } else {
183829037a9SAmara Emerson       for (unsigned I = First; I <= Last; ++I)
184829037a9SAmara Emerson         std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
185829037a9SAmara Emerson     }
186829037a9SAmara Emerson   }
187829037a9SAmara Emerson   Clusters.resize(DstIndex);
188829037a9SAmara Emerson }
189829037a9SAmara Emerson 
buildJumpTable(const CaseClusterVector & Clusters,unsigned First,unsigned Last,const SwitchInst * SI,MachineBasicBlock * DefaultMBB,CaseCluster & JTCluster)190829037a9SAmara Emerson bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
191829037a9SAmara Emerson                                               unsigned First, unsigned Last,
192829037a9SAmara Emerson                                               const SwitchInst *SI,
193829037a9SAmara Emerson                                               MachineBasicBlock *DefaultMBB,
194829037a9SAmara Emerson                                               CaseCluster &JTCluster) {
195829037a9SAmara Emerson   assert(First <= Last);
196829037a9SAmara Emerson 
197829037a9SAmara Emerson   auto Prob = BranchProbability::getZero();
198829037a9SAmara Emerson   unsigned NumCmps = 0;
199829037a9SAmara Emerson   std::vector<MachineBasicBlock*> Table;
200829037a9SAmara Emerson   DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;
201829037a9SAmara Emerson 
202829037a9SAmara Emerson   // Initialize probabilities in JTProbs.
203829037a9SAmara Emerson   for (unsigned I = First; I <= Last; ++I)
204829037a9SAmara Emerson     JTProbs[Clusters[I].MBB] = BranchProbability::getZero();
205829037a9SAmara Emerson 
206829037a9SAmara Emerson   for (unsigned I = First; I <= Last; ++I) {
207829037a9SAmara Emerson     assert(Clusters[I].Kind == CC_Range);
208829037a9SAmara Emerson     Prob += Clusters[I].Prob;
209829037a9SAmara Emerson     const APInt &Low = Clusters[I].Low->getValue();
210829037a9SAmara Emerson     const APInt &High = Clusters[I].High->getValue();
211829037a9SAmara Emerson     NumCmps += (Low == High) ? 1 : 2;
212829037a9SAmara Emerson     if (I != First) {
213829037a9SAmara Emerson       // Fill the gap between this and the previous cluster.
214829037a9SAmara Emerson       const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
215829037a9SAmara Emerson       assert(PreviousHigh.slt(Low));
216829037a9SAmara Emerson       uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
217829037a9SAmara Emerson       for (uint64_t J = 0; J < Gap; J++)
218829037a9SAmara Emerson         Table.push_back(DefaultMBB);
219829037a9SAmara Emerson     }
220829037a9SAmara Emerson     uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
221829037a9SAmara Emerson     for (uint64_t J = 0; J < ClusterSize; ++J)
222829037a9SAmara Emerson       Table.push_back(Clusters[I].MBB);
223829037a9SAmara Emerson     JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
224829037a9SAmara Emerson   }
225829037a9SAmara Emerson 
226829037a9SAmara Emerson   unsigned NumDests = JTProbs.size();
227829037a9SAmara Emerson   if (TLI->isSuitableForBitTests(NumDests, NumCmps,
228829037a9SAmara Emerson                                  Clusters[First].Low->getValue(),
229829037a9SAmara Emerson                                  Clusters[Last].High->getValue(), *DL)) {
230829037a9SAmara Emerson     // Clusters[First..Last] should be lowered as bit tests instead.
231829037a9SAmara Emerson     return false;
232829037a9SAmara Emerson   }
233829037a9SAmara Emerson 
234829037a9SAmara Emerson   // Create the MBB that will load from and jump through the table.
235829037a9SAmara Emerson   // Note: We create it here, but it's not inserted into the function yet.
236829037a9SAmara Emerson   MachineFunction *CurMF = FuncInfo.MF;
237829037a9SAmara Emerson   MachineBasicBlock *JumpTableMBB =
238829037a9SAmara Emerson       CurMF->CreateMachineBasicBlock(SI->getParent());
239829037a9SAmara Emerson 
240829037a9SAmara Emerson   // Add successors. Note: use table order for determinism.
241829037a9SAmara Emerson   SmallPtrSet<MachineBasicBlock *, 8> Done;
242829037a9SAmara Emerson   for (MachineBasicBlock *Succ : Table) {
243829037a9SAmara Emerson     if (Done.count(Succ))
244829037a9SAmara Emerson       continue;
245829037a9SAmara Emerson     addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
246829037a9SAmara Emerson     Done.insert(Succ);
247829037a9SAmara Emerson   }
248829037a9SAmara Emerson   JumpTableMBB->normalizeSuccProbs();
249829037a9SAmara Emerson 
250829037a9SAmara Emerson   unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
251829037a9SAmara Emerson                      ->createJumpTableIndex(Table);
252829037a9SAmara Emerson 
253829037a9SAmara Emerson   // Set up the jump table info.
254829037a9SAmara Emerson   JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
255829037a9SAmara Emerson   JumpTableHeader JTH(Clusters[First].Low->getValue(),
256829037a9SAmara Emerson                       Clusters[Last].High->getValue(), SI->getCondition(),
257829037a9SAmara Emerson                       nullptr, false);
258829037a9SAmara Emerson   JTCases.emplace_back(std::move(JTH), std::move(JT));
259829037a9SAmara Emerson 
260829037a9SAmara Emerson   JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
261829037a9SAmara Emerson                                      JTCases.size() - 1, Prob);
262829037a9SAmara Emerson   return true;
263829037a9SAmara Emerson }
264829037a9SAmara Emerson 
findBitTestClusters(CaseClusterVector & Clusters,const SwitchInst * SI)265829037a9SAmara Emerson void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
266829037a9SAmara Emerson                                                    const SwitchInst *SI) {
267829037a9SAmara Emerson   // Partition Clusters into as few subsets as possible, where each subset has a
268829037a9SAmara Emerson   // range that fits in a machine word and has <= 3 unique destinations.
269829037a9SAmara Emerson 
270829037a9SAmara Emerson #ifndef NDEBUG
271829037a9SAmara Emerson   // Clusters must be sorted and contain Range or JumpTable clusters.
272829037a9SAmara Emerson   assert(!Clusters.empty());
273829037a9SAmara Emerson   assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
274829037a9SAmara Emerson   for (const CaseCluster &C : Clusters)
275829037a9SAmara Emerson     assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
276829037a9SAmara Emerson   for (unsigned i = 1; i < Clusters.size(); ++i)
277829037a9SAmara Emerson     assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
278829037a9SAmara Emerson #endif
279829037a9SAmara Emerson 
280829037a9SAmara Emerson   // The algorithm below is not suitable for -O0.
281829037a9SAmara Emerson   if (TM->getOptLevel() == CodeGenOpt::None)
282829037a9SAmara Emerson     return;
283829037a9SAmara Emerson 
284829037a9SAmara Emerson   // If target does not have legal shift left, do not emit bit tests at all.
285829037a9SAmara Emerson   EVT PTy = TLI->getPointerTy(*DL);
286829037a9SAmara Emerson   if (!TLI->isOperationLegal(ISD::SHL, PTy))
287829037a9SAmara Emerson     return;
288829037a9SAmara Emerson 
289829037a9SAmara Emerson   int BitWidth = PTy.getSizeInBits();
290829037a9SAmara Emerson   const int64_t N = Clusters.size();
291829037a9SAmara Emerson 
292829037a9SAmara Emerson   // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
293829037a9SAmara Emerson   SmallVector<unsigned, 8> MinPartitions(N);
294829037a9SAmara Emerson   // LastElement[i] is the last element of the partition starting at i.
295829037a9SAmara Emerson   SmallVector<unsigned, 8> LastElement(N);
296829037a9SAmara Emerson 
297829037a9SAmara Emerson   // FIXME: This might not be the best algorithm for finding bit test clusters.
298829037a9SAmara Emerson 
299829037a9SAmara Emerson   // Base case: There is only one way to partition Clusters[N-1].
300829037a9SAmara Emerson   MinPartitions[N - 1] = 1;
301829037a9SAmara Emerson   LastElement[N - 1] = N - 1;
302829037a9SAmara Emerson 
303829037a9SAmara Emerson   // Note: loop indexes are signed to avoid underflow.
304829037a9SAmara Emerson   for (int64_t i = N - 2; i >= 0; --i) {
305829037a9SAmara Emerson     // Find optimal partitioning of Clusters[i..N-1].
306829037a9SAmara Emerson     // Baseline: Put Clusters[i] into a partition on its own.
307829037a9SAmara Emerson     MinPartitions[i] = MinPartitions[i + 1] + 1;
308829037a9SAmara Emerson     LastElement[i] = i;
309829037a9SAmara Emerson 
310829037a9SAmara Emerson     // Search for a solution that results in fewer partitions.
311829037a9SAmara Emerson     // Note: the search is limited by BitWidth, reducing time complexity.
312829037a9SAmara Emerson     for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
313829037a9SAmara Emerson       // Try building a partition from Clusters[i..j].
314829037a9SAmara Emerson 
315829037a9SAmara Emerson       // Check the range.
316829037a9SAmara Emerson       if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
317829037a9SAmara Emerson                                 Clusters[j].High->getValue(), *DL))
318829037a9SAmara Emerson         continue;
319829037a9SAmara Emerson 
320829037a9SAmara Emerson       // Check nbr of destinations and cluster types.
321829037a9SAmara Emerson       // FIXME: This works, but doesn't seem very efficient.
322829037a9SAmara Emerson       bool RangesOnly = true;
323829037a9SAmara Emerson       BitVector Dests(FuncInfo.MF->getNumBlockIDs());
324829037a9SAmara Emerson       for (int64_t k = i; k <= j; k++) {
325829037a9SAmara Emerson         if (Clusters[k].Kind != CC_Range) {
326829037a9SAmara Emerson           RangesOnly = false;
327829037a9SAmara Emerson           break;
328829037a9SAmara Emerson         }
329829037a9SAmara Emerson         Dests.set(Clusters[k].MBB->getNumber());
330829037a9SAmara Emerson       }
331829037a9SAmara Emerson       if (!RangesOnly || Dests.count() > 3)
332829037a9SAmara Emerson         break;
333829037a9SAmara Emerson 
334829037a9SAmara Emerson       // Check if it's a better partition.
335829037a9SAmara Emerson       unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
336829037a9SAmara Emerson       if (NumPartitions < MinPartitions[i]) {
337829037a9SAmara Emerson         // Found a better partition.
338829037a9SAmara Emerson         MinPartitions[i] = NumPartitions;
339829037a9SAmara Emerson         LastElement[i] = j;
340829037a9SAmara Emerson       }
341829037a9SAmara Emerson     }
342829037a9SAmara Emerson   }
343829037a9SAmara Emerson 
344829037a9SAmara Emerson   // Iterate over the partitions, replacing with bit-test clusters in-place.
345829037a9SAmara Emerson   unsigned DstIndex = 0;
346829037a9SAmara Emerson   for (unsigned First = 0, Last; First < N; First = Last + 1) {
347829037a9SAmara Emerson     Last = LastElement[First];
348829037a9SAmara Emerson     assert(First <= Last);
349829037a9SAmara Emerson     assert(DstIndex <= First);
350829037a9SAmara Emerson 
351829037a9SAmara Emerson     CaseCluster BitTestCluster;
352829037a9SAmara Emerson     if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
353829037a9SAmara Emerson       Clusters[DstIndex++] = BitTestCluster;
354829037a9SAmara Emerson     } else {
355829037a9SAmara Emerson       size_t NumClusters = Last - First + 1;
356829037a9SAmara Emerson       std::memmove(&Clusters[DstIndex], &Clusters[First],
357829037a9SAmara Emerson                    sizeof(Clusters[0]) * NumClusters);
358829037a9SAmara Emerson       DstIndex += NumClusters;
359829037a9SAmara Emerson     }
360829037a9SAmara Emerson   }
361829037a9SAmara Emerson   Clusters.resize(DstIndex);
362829037a9SAmara Emerson }
363829037a9SAmara Emerson 
buildBitTests(CaseClusterVector & Clusters,unsigned First,unsigned Last,const SwitchInst * SI,CaseCluster & BTCluster)364829037a9SAmara Emerson bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
365829037a9SAmara Emerson                                              unsigned First, unsigned Last,
366829037a9SAmara Emerson                                              const SwitchInst *SI,
367829037a9SAmara Emerson                                              CaseCluster &BTCluster) {
368829037a9SAmara Emerson   assert(First <= Last);
369829037a9SAmara Emerson   if (First == Last)
370829037a9SAmara Emerson     return false;
371829037a9SAmara Emerson 
372829037a9SAmara Emerson   BitVector Dests(FuncInfo.MF->getNumBlockIDs());
373829037a9SAmara Emerson   unsigned NumCmps = 0;
374829037a9SAmara Emerson   for (int64_t I = First; I <= Last; ++I) {
375829037a9SAmara Emerson     assert(Clusters[I].Kind == CC_Range);
376829037a9SAmara Emerson     Dests.set(Clusters[I].MBB->getNumber());
377829037a9SAmara Emerson     NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
378829037a9SAmara Emerson   }
379829037a9SAmara Emerson   unsigned NumDests = Dests.count();
380829037a9SAmara Emerson 
381829037a9SAmara Emerson   APInt Low = Clusters[First].Low->getValue();
382829037a9SAmara Emerson   APInt High = Clusters[Last].High->getValue();
383829037a9SAmara Emerson   assert(Low.slt(High));
384829037a9SAmara Emerson 
385829037a9SAmara Emerson   if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
386829037a9SAmara Emerson     return false;
387829037a9SAmara Emerson 
388829037a9SAmara Emerson   APInt LowBound;
389829037a9SAmara Emerson   APInt CmpRange;
390829037a9SAmara Emerson 
391829037a9SAmara Emerson   const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
392829037a9SAmara Emerson   assert(TLI->rangeFitsInWord(Low, High, *DL) &&
393829037a9SAmara Emerson          "Case range must fit in bit mask!");
394829037a9SAmara Emerson 
395829037a9SAmara Emerson   // Check if the clusters cover a contiguous range such that no value in the
396829037a9SAmara Emerson   // range will jump to the default statement.
397829037a9SAmara Emerson   bool ContiguousRange = true;
398829037a9SAmara Emerson   for (int64_t I = First + 1; I <= Last; ++I) {
399829037a9SAmara Emerson     if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
400829037a9SAmara Emerson       ContiguousRange = false;
401829037a9SAmara Emerson       break;
402829037a9SAmara Emerson     }
403829037a9SAmara Emerson   }
404829037a9SAmara Emerson 
405829037a9SAmara Emerson   if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
406829037a9SAmara Emerson     // Optimize the case where all the case values fit in a word without having
407829037a9SAmara Emerson     // to subtract minValue. In this case, we can optimize away the subtraction.
408*735f4671SChris Lattner     LowBound = APInt::getZero(Low.getBitWidth());
409829037a9SAmara Emerson     CmpRange = High;
410829037a9SAmara Emerson     ContiguousRange = false;
411829037a9SAmara Emerson   } else {
412829037a9SAmara Emerson     LowBound = Low;
413829037a9SAmara Emerson     CmpRange = High - Low;
414829037a9SAmara Emerson   }
415829037a9SAmara Emerson 
416829037a9SAmara Emerson   CaseBitsVector CBV;
417829037a9SAmara Emerson   auto TotalProb = BranchProbability::getZero();
418829037a9SAmara Emerson   for (unsigned i = First; i <= Last; ++i) {
419829037a9SAmara Emerson     // Find the CaseBits for this destination.
420829037a9SAmara Emerson     unsigned j;
421829037a9SAmara Emerson     for (j = 0; j < CBV.size(); ++j)
422829037a9SAmara Emerson       if (CBV[j].BB == Clusters[i].MBB)
423829037a9SAmara Emerson         break;
424829037a9SAmara Emerson     if (j == CBV.size())
425829037a9SAmara Emerson       CBV.push_back(
426829037a9SAmara Emerson           CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
427829037a9SAmara Emerson     CaseBits *CB = &CBV[j];
428829037a9SAmara Emerson 
429829037a9SAmara Emerson     // Update Mask, Bits and ExtraProb.
430829037a9SAmara Emerson     uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
431829037a9SAmara Emerson     uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
432829037a9SAmara Emerson     assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
433829037a9SAmara Emerson     CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
434829037a9SAmara Emerson     CB->Bits += Hi - Lo + 1;
435829037a9SAmara Emerson     CB->ExtraProb += Clusters[i].Prob;
436829037a9SAmara Emerson     TotalProb += Clusters[i].Prob;
437829037a9SAmara Emerson   }
438829037a9SAmara Emerson 
439829037a9SAmara Emerson   BitTestInfo BTI;
440829037a9SAmara Emerson   llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
441829037a9SAmara Emerson     // Sort by probability first, number of bits second, bit mask third.
442829037a9SAmara Emerson     if (a.ExtraProb != b.ExtraProb)
443829037a9SAmara Emerson       return a.ExtraProb > b.ExtraProb;
444829037a9SAmara Emerson     if (a.Bits != b.Bits)
445829037a9SAmara Emerson       return a.Bits > b.Bits;
446829037a9SAmara Emerson     return a.Mask < b.Mask;
447829037a9SAmara Emerson   });
448829037a9SAmara Emerson 
449829037a9SAmara Emerson   for (auto &CB : CBV) {
450829037a9SAmara Emerson     MachineBasicBlock *BitTestBB =
451829037a9SAmara Emerson         FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
452829037a9SAmara Emerson     BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
453829037a9SAmara Emerson   }
454829037a9SAmara Emerson   BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
455829037a9SAmara Emerson                             SI->getCondition(), -1U, MVT::Other, false,
456829037a9SAmara Emerson                             ContiguousRange, nullptr, nullptr, std::move(BTI),
457829037a9SAmara Emerson                             TotalProb);
458829037a9SAmara Emerson 
459829037a9SAmara Emerson   BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
460829037a9SAmara Emerson                                     BitTestCases.size() - 1, TotalProb);
461829037a9SAmara Emerson   return true;
462829037a9SAmara Emerson }
463829037a9SAmara Emerson 
sortAndRangeify(CaseClusterVector & Clusters)464829037a9SAmara Emerson void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
465829037a9SAmara Emerson #ifndef NDEBUG
466829037a9SAmara Emerson   for (const CaseCluster &CC : Clusters)
467829037a9SAmara Emerson     assert(CC.Low == CC.High && "Input clusters must be single-case");
468829037a9SAmara Emerson #endif
469829037a9SAmara Emerson 
470829037a9SAmara Emerson   llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
471829037a9SAmara Emerson     return a.Low->getValue().slt(b.Low->getValue());
472829037a9SAmara Emerson   });
473829037a9SAmara Emerson 
474829037a9SAmara Emerson   // Merge adjacent clusters with the same destination.
475829037a9SAmara Emerson   const unsigned N = Clusters.size();
476829037a9SAmara Emerson   unsigned DstIndex = 0;
477829037a9SAmara Emerson   for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
478829037a9SAmara Emerson     CaseCluster &CC = Clusters[SrcIndex];
479829037a9SAmara Emerson     const ConstantInt *CaseVal = CC.Low;
480829037a9SAmara Emerson     MachineBasicBlock *Succ = CC.MBB;
481829037a9SAmara Emerson 
482829037a9SAmara Emerson     if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
483829037a9SAmara Emerson         (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
484829037a9SAmara Emerson       // If this case has the same successor and is a neighbour, merge it into
485829037a9SAmara Emerson       // the previous cluster.
486829037a9SAmara Emerson       Clusters[DstIndex - 1].High = CaseVal;
487829037a9SAmara Emerson       Clusters[DstIndex - 1].Prob += CC.Prob;
488829037a9SAmara Emerson     } else {
489829037a9SAmara Emerson       std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
490829037a9SAmara Emerson                    sizeof(Clusters[SrcIndex]));
491829037a9SAmara Emerson     }
492829037a9SAmara Emerson   }
493829037a9SAmara Emerson   Clusters.resize(DstIndex);
494829037a9SAmara Emerson }
495