1829037a9SAmara Emerson //===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
2829037a9SAmara Emerson //
3829037a9SAmara Emerson // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4829037a9SAmara Emerson // See https://llvm.org/LICENSE.txt for license information.
5829037a9SAmara Emerson // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6829037a9SAmara Emerson //
7829037a9SAmara Emerson //===----------------------------------------------------------------------===//
8829037a9SAmara Emerson //
9829037a9SAmara Emerson // This file contains switch inst lowering optimizations and utilities for
10829037a9SAmara Emerson // codegen, so that it can be used for both SelectionDAG and GlobalISel.
11829037a9SAmara Emerson //
12829037a9SAmara Emerson //===----------------------------------------------------------------------===//
13829037a9SAmara Emerson
14829037a9SAmara Emerson #include "llvm/CodeGen/SwitchLoweringUtils.h"
151673a080SSimon Pilgrim #include "llvm/CodeGen/FunctionLoweringInfo.h"
161673a080SSimon Pilgrim #include "llvm/CodeGen/MachineJumpTableInfo.h"
17f42f733aSSimon Pilgrim #include "llvm/CodeGen/TargetLowering.h"
18fe0006c8SSimon Pilgrim #include "llvm/Target/TargetMachine.h"
19829037a9SAmara Emerson
20829037a9SAmara Emerson using namespace llvm;
21829037a9SAmara Emerson using namespace SwitchCG;
22829037a9SAmara Emerson
getJumpTableRange(const CaseClusterVector & Clusters,unsigned First,unsigned Last)233740ae3bSHans Wennborg uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
243740ae3bSHans Wennborg unsigned First, unsigned Last) {
253740ae3bSHans Wennborg assert(Last >= First);
263740ae3bSHans Wennborg const APInt &LowCase = Clusters[First].Low->getValue();
273740ae3bSHans Wennborg const APInt &HighCase = Clusters[Last].High->getValue();
283740ae3bSHans Wennborg assert(LowCase.getBitWidth() == HighCase.getBitWidth());
29829037a9SAmara Emerson
303740ae3bSHans Wennborg // FIXME: A range of consecutive cases has 100% density, but only requires one
313740ae3bSHans Wennborg // comparison to lower. We should discriminate against such consecutive ranges
323740ae3bSHans Wennborg // in jump tables.
333740ae3bSHans Wennborg return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
34829037a9SAmara Emerson }
35829037a9SAmara Emerson
363740ae3bSHans Wennborg uint64_t
getJumpTableNumCases(const SmallVectorImpl<unsigned> & TotalCases,unsigned First,unsigned Last)373740ae3bSHans Wennborg SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
383740ae3bSHans Wennborg unsigned First, unsigned Last) {
393740ae3bSHans Wennborg assert(Last >= First);
403740ae3bSHans Wennborg assert(TotalCases[Last] >= TotalCases[First]);
413740ae3bSHans Wennborg uint64_t NumCases =
423740ae3bSHans Wennborg TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
433740ae3bSHans Wennborg return NumCases;
44829037a9SAmara Emerson }
45829037a9SAmara Emerson
/// Partition \p Clusters into jump tables where profitable, rewriting
/// \p Clusters in place.
///
/// First tries the cheap whole-range check; failing that (and only when
/// optimizing), runs a dynamic program over suffixes of Clusters to find the
/// minimum number of partitions, breaking ties in favor of partitionings that
/// score better (more jump tables / fewer comparisons), then replaces each
/// table-worthy partition with a single CC_JumpTable cluster.
void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              ProfileSummaryInfo *PSI,
                                              BlockFrequencyInfo *BFI) {
#ifndef NDEBUG
  // Clusters must be non-empty, sorted, and only contain Range clusters.
  assert(!Clusters.empty());
  for (CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range);
  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  assert(TLI && "TLI not set!");
  if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
    return;

  const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;

  // Bail if not enough cases.
  const int64_t N = Clusters.size();
  if (N < 2 || N < MinJumpTableEntries)
    return;

  // Accumulated number of cases in each cluster and those prior to it.
  // TotalCases[i] is the inclusive prefix sum consumed by
  // getJumpTableNumCases.
  SmallVector<unsigned, 8> TotalCases(N);
  for (unsigned i = 0; i < N; ++i) {
    const APInt &Hi = Clusters[i].High->getValue();
    const APInt &Lo = Clusters[i].Low->getValue();
    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
    if (i != 0)
      TotalCases[i] += TotalCases[i - 1];
  }

  uint64_t Range = getJumpTableRange(Clusters, 0, N - 1);
  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
  assert(NumCases < UINT64_MAX / 100);
  assert(Range >= NumCases);

  // Cheap case: the whole range may be suitable for jump table.
  if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
    CaseCluster JTCluster;
    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
      // One table covers everything; collapse Clusters to that single entry.
      Clusters[0] = JTCluster;
      Clusters.resize(1);
      return;
    }
  }

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // Split Clusters into minimum number of dense partitions. The algorithm uses
  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
  // for the Case Statement'" (1994), but builds the MinPartitions array in
  // reverse order to make it easier to reconstruct the partitions in ascending
  // order. In the choice between two optimal partitionings, it picks the one
  // which yields more jump tables.

  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);
  // PartitionsScore[i] is used to break ties when choosing between two
  // partitionings resulting in the same number of partitions.
  SmallVector<unsigned, 8> PartitionsScore(N);
  // For PartitionsScore, a small number of comparisons is considered as good as
  // a jump table and a single comparison is considered better than a jump
  // table.
  enum PartitionScores : unsigned {
    NoTable = 0,
    Table = 1,
    FewCases = 1,
    SingleCase = 2
  };

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  PartitionsScore[N - 1] = PartitionScores::SingleCase;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; i--) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;
    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;

    // Search for a solution that results in fewer partitions.
    for (int64_t j = N - 1; j > i; j--) {
      // Try building a partition from Clusters[i..j].
      Range = getJumpTableRange(Clusters, i, j);
      NumCases = getJumpTableNumCases(TotalCases, i, j);
      assert(NumCases < UINT64_MAX / 100);
      assert(Range >= NumCases);

      if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
        int64_t NumEntries = j - i + 1;

        if (NumEntries == 1)
          Score += PartitionScores::SingleCase;
        else if (NumEntries <= SmallNumberOfEntries)
          Score += PartitionScores::FewCases;
        else if (NumEntries >= MinJumpTableEntries)
          Score += PartitionScores::Table;

        // If this leads to fewer partitions, or to the same number of
        // partitions with better score, it is a better partitioning.
        if (NumPartitions < MinPartitions[i] ||
            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
          MinPartitions[i] = NumPartitions;
          LastElement[i] = j;
          PartitionsScore[i] = Score;
        }
      }
    }
  }

  // Iterate over the partitions, replacing some with jump tables in-place.
  // DstIndex trails First, so each cluster is written at or before its
  // original slot and the vector only ever shrinks.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(Last >= First);
    assert(DstIndex <= First);
    unsigned NumClusters = Last - First + 1;

    CaseCluster JTCluster;
    if (NumClusters >= MinJumpTableEntries &&
        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
      Clusters[DstIndex++] = JTCluster;
    } else {
      // Partition not lowered as a table: slide its clusters down to fill
      // the slots freed by earlier table replacements.
      for (unsigned I = First; I <= Last; ++I)
        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
    }
  }
  Clusters.resize(DstIndex);
}
189829037a9SAmara Emerson
/// Attempt to lower Clusters[First..Last] as a single jump table.
///
/// Builds the flat destination table (padding inter-cluster gaps with
/// \p DefaultMBB), creates the jump-table machine basic block, records a
/// JumpTableHeader/JumpTable pair in JTCases, and sets \p JTCluster to a
/// CC_JumpTable cluster covering the range. Returns false (emitting nothing)
/// if TLI reports the range is better lowered as bit tests.
bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
                                              unsigned First, unsigned Last,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              CaseCluster &JTCluster) {
  assert(First <= Last);

  auto Prob = BranchProbability::getZero();
  unsigned NumCmps = 0;
  std::vector<MachineBasicBlock*> Table;
  DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;

  // Initialize probabilities in JTProbs.
  for (unsigned I = First; I <= Last; ++I)
    JTProbs[Clusters[I].MBB] = BranchProbability::getZero();

  // Populate the table and accumulate per-destination probabilities.
  for (unsigned I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Prob += Clusters[I].Prob;
    const APInt &Low = Clusters[I].Low->getValue();
    const APInt &High = Clusters[I].High->getValue();
    // A single-value cluster needs one comparison to lower; a true range
    // needs two. NumCmps feeds the bit-test suitability check below.
    NumCmps += (Low == High) ? 1 : 2;
    if (I != First) {
      // Fill the gap between this and the previous cluster.
      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
      assert(PreviousHigh.slt(Low));
      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
      for (uint64_t J = 0; J < Gap; J++)
        Table.push_back(DefaultMBB);
    }
    // One table entry per case value in the cluster's range.
    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
    for (uint64_t J = 0; J < ClusterSize; ++J)
      Table.push_back(Clusters[I].MBB);
    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
  }

  unsigned NumDests = JTProbs.size();
  if (TLI->isSuitableForBitTests(NumDests, NumCmps,
                                 Clusters[First].Low->getValue(),
                                 Clusters[Last].High->getValue(), *DL)) {
    // Clusters[First..Last] should be lowered as bit tests instead.
    return false;
  }

  // Create the MBB that will load from and jump through the table.
  // Note: We create it here, but it's not inserted into the function yet.
  MachineFunction *CurMF = FuncInfo.MF;
  MachineBasicBlock *JumpTableMBB =
      CurMF->CreateMachineBasicBlock(SI->getParent());

  // Add successors. Note: use table order for determinism.
  SmallPtrSet<MachineBasicBlock *, 8> Done;
  for (MachineBasicBlock *Succ : Table) {
    if (Done.count(Succ))
      continue;
    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
    Done.insert(Succ);
  }
  JumpTableMBB->normalizeSuccProbs();

  unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
                     ->createJumpTableIndex(Table);

  // Set up the jump table info. The MBB fields left as -1U/nullptr here are
  // presumably filled in later during lowering — they are not set in this
  // function.
  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
  JumpTableHeader JTH(Clusters[First].Low->getValue(),
                      Clusters[Last].High->getValue(), SI->getCondition(),
                      nullptr, false);
  JTCases.emplace_back(std::move(JTH), std::move(JT));

  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                     JTCases.size() - 1, Prob);
  return true;
}
264829037a9SAmara Emerson
/// Partition \p Clusters into bit-test clusters where possible, rewriting
/// \p Clusters in place.
///
/// Runs a dynamic program (analogous to findJumpTables) that splits Clusters
/// into the fewest partitions such that each chosen partition spans a range
/// fitting in one machine word and has at most 3 distinct destinations; each
/// such partition is then handed to buildBitTests.
void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
                                                   const SwitchInst *SI) {
  // Partition Clusters into as few subsets as possible, where each subset has a
  // range that fits in a machine word and has <= 3 unique destinations.

#ifndef NDEBUG
  // Clusters must be sorted and contain Range or JumpTable clusters.
  assert(!Clusters.empty());
  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
  for (const CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  for (unsigned i = 1; i < Clusters.size(); ++i)
    assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // If target does not have legal shift left, do not emit bit tests at all.
  EVT PTy = TLI->getPointerTy(*DL);
  if (!TLI->isOperationLegal(ISD::SHL, PTy))
    return;

  int BitWidth = PTy.getSizeInBits();
  const int64_t N = Clusters.size();

  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);

  // FIXME: This might not be the best algorithm for finding bit test clusters.

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; --i) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;

    // Search for a solution that results in fewer partitions.
    // Note: the search is limited by BitWidth, reducing time complexity.
    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
      // Try building a partition from Clusters[i..j].

      // Check the range.
      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
                                Clusters[j].High->getValue(), *DL))
        continue;

      // Check nbr of destinations and cluster types.
      // FIXME: This works, but doesn't seem very efficient.
      bool RangesOnly = true;
      BitVector Dests(FuncInfo.MF->getNumBlockIDs());
      for (int64_t k = i; k <= j; k++) {
        if (Clusters[k].Kind != CC_Range) {
          RangesOnly = false;
          break;
        }
        Dests.set(Clusters[k].MBB->getNumber());
      }
      if (!RangesOnly || Dests.count() > 3)
        break;

      // Check if it's a better partition.
      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
      if (NumPartitions < MinPartitions[i]) {
        // Found a better partition.
        MinPartitions[i] = NumPartitions;
        LastElement[i] = j;
      }
    }
  }

  // Iterate over the partitions, replacing with bit-test clusters in-place.
  // DstIndex trails First, so writes never clobber unread clusters.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(First <= Last);
    assert(DstIndex <= First);

    CaseCluster BitTestCluster;
    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
      Clusters[DstIndex++] = BitTestCluster;
    } else {
      // Partition rejected: compact its clusters down unchanged.
      size_t NumClusters = Last - First + 1;
      std::memmove(&Clusters[DstIndex], &Clusters[First],
                   sizeof(Clusters[0]) * NumClusters);
      DstIndex += NumClusters;
    }
  }
  Clusters.resize(DstIndex);
}
363829037a9SAmara Emerson
/// Attempt to lower Clusters[First..Last] as a bit-test sequence.
///
/// Collapses the clusters into per-destination bit masks, records a
/// BitTestBlock in BitTestCases, and sets \p BTCluster to a CC_BitTests
/// cluster covering the range. Returns false if there is only one cluster or
/// TLI deems the range unsuitable for bit tests.
bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
                                             unsigned First, unsigned Last,
                                             const SwitchInst *SI,
                                             CaseCluster &BTCluster) {
  assert(First <= Last);
  // A single cluster never benefits from bit tests.
  if (First == Last)
    return false;

  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  unsigned NumCmps = 0;
  for (int64_t I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Dests.set(Clusters[I].MBB->getNumber());
    // One comparison for a single value, two for a range.
    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  }
  unsigned NumDests = Dests.count();

  APInt Low = Clusters[First].Low->getValue();
  APInt High = Clusters[Last].High->getValue();
  assert(Low.slt(High));

  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
    return false;

  APInt LowBound;
  APInt CmpRange;

  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
         "Case range must fit in bit mask!");

  // Check if the clusters cover a contiguous range such that no value in the
  // range will jump to the default statement.
  bool ContiguousRange = true;
  for (int64_t I = First + 1; I <= Last; ++I) {
    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
      ContiguousRange = false;
      break;
    }
  }

  if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
    // Optimize the case where all the case values fit in a word without having
    // to subtract minValue. In this case, we can optimize away the subtraction.
    LowBound = APInt::getZero(Low.getBitWidth());
    CmpRange = High;
    // The tested range now starts at 0, so values in [0, Low) fall outside
    // every cluster mask and must reach the default — hence the range can no
    // longer be treated as fully covered.
    ContiguousRange = false;
  } else {
    LowBound = Low;
    CmpRange = High - Low;
  }

  // Build one CaseBits entry (mask + stats) per distinct destination block.
  CaseBitsVector CBV;
  auto TotalProb = BranchProbability::getZero();
  for (unsigned i = First; i <= Last; ++i) {
    // Find the CaseBits for this destination.
    unsigned j;
    for (j = 0; j < CBV.size(); ++j)
      if (CBV[j].BB == Clusters[i].MBB)
        break;
    if (j == CBV.size())
      CBV.push_back(
          CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
    CaseBits *CB = &CBV[j];

    // Update Mask, Bits and ExtraProb.
    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
    // Set a run of (Hi - Lo + 1) consecutive bits starting at bit Lo.
    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
    CB->Bits += Hi - Lo + 1;
    CB->ExtraProb += Clusters[i].Prob;
    TotalProb += Clusters[i].Prob;
  }

  BitTestInfo BTI;
  llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
    // Sort by probability first, number of bits second, bit mask third.
    if (a.ExtraProb != b.ExtraProb)
      return a.ExtraProb > b.ExtraProb;
    if (a.Bits != b.Bits)
      return a.Bits > b.Bits;
    return a.Mask < b.Mask;
  });

  // One fresh MBB per bit-test case; like buildJumpTable's MBB, these are
  // created but not yet inserted into the function.
  for (auto &CB : CBV) {
    MachineBasicBlock *BitTestBB =
        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
    BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
  }
  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                            SI->getCondition(), -1U, MVT::Other, false,
                            ContiguousRange, nullptr, nullptr, std::move(BTI),
                            TotalProb);

  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                    BitTestCases.size() - 1, TotalProb);
  return true;
}
463829037a9SAmara Emerson
sortAndRangeify(CaseClusterVector & Clusters)464829037a9SAmara Emerson void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
465829037a9SAmara Emerson #ifndef NDEBUG
466829037a9SAmara Emerson for (const CaseCluster &CC : Clusters)
467829037a9SAmara Emerson assert(CC.Low == CC.High && "Input clusters must be single-case");
468829037a9SAmara Emerson #endif
469829037a9SAmara Emerson
470829037a9SAmara Emerson llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
471829037a9SAmara Emerson return a.Low->getValue().slt(b.Low->getValue());
472829037a9SAmara Emerson });
473829037a9SAmara Emerson
474829037a9SAmara Emerson // Merge adjacent clusters with the same destination.
475829037a9SAmara Emerson const unsigned N = Clusters.size();
476829037a9SAmara Emerson unsigned DstIndex = 0;
477829037a9SAmara Emerson for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
478829037a9SAmara Emerson CaseCluster &CC = Clusters[SrcIndex];
479829037a9SAmara Emerson const ConstantInt *CaseVal = CC.Low;
480829037a9SAmara Emerson MachineBasicBlock *Succ = CC.MBB;
481829037a9SAmara Emerson
482829037a9SAmara Emerson if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
483829037a9SAmara Emerson (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
484829037a9SAmara Emerson // If this case has the same successor and is a neighbour, merge it into
485829037a9SAmara Emerson // the previous cluster.
486829037a9SAmara Emerson Clusters[DstIndex - 1].High = CaseVal;
487829037a9SAmara Emerson Clusters[DstIndex - 1].Prob += CC.Prob;
488829037a9SAmara Emerson } else {
489829037a9SAmara Emerson std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
490829037a9SAmara Emerson sizeof(Clusters[SrcIndex]));
491829037a9SAmara Emerson }
492829037a9SAmara Emerson }
493829037a9SAmara Emerson Clusters.resize(DstIndex);
494829037a9SAmara Emerson }
495