//===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains switch instruction lowering optimizations and utilities
// for codegen, so that they can be used by both SelectionDAG and GlobalISel.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/SwitchLoweringUtils.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace SwitchCG;
uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
                                     unsigned First, unsigned Last) {
  assert(Last >= First);
  const APInt &LowCase = Clusters[First].Low->getValue();
  const APInt &HighCase = Clusters[Last].High->getValue();
  assert(LowCase.getBitWidth() == HighCase.getBitWidth());

  // FIXME: A range of consecutive cases has 100% density, but only requires one
  // comparison to lower. We should discriminate against such consecutive ranges
  // in jump tables.
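  // For example, clusters {1..3} and {8..10} handle only six cases but span
  // the values [1, 10], so the computed range is 10.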
  return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
}

uint64_t
SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
                               unsigned First, unsigned Last) {
  assert(Last >= First);
  assert(TotalCases[Last] >= TotalCases[First]);
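  // TotalCases is a prefix-sum array, so subtracting the sum just before First
  // leaves exactly the cases in Clusters[First..Last]. E.g., for per-cluster
  // counts {3, 1, 2}, TotalCases = {3, 4, 6} and Clusters[1..2] hold
  // 6 - 3 = 3 cases.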
  uint64_t NumCases =
      TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
  return NumCases;
}

void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              ProfileSummaryInfo *PSI,
                                              BlockFrequencyInfo *BFI) {
#ifndef NDEBUG
  // Clusters must be non-empty, sorted, and only contain Range clusters.
  assert(!Clusters.empty());
  for (CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range);
  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  assert(TLI && "TLI not set!");
  if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
    return;

  const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;

  // Bail if not enough cases.
  const int64_t N = Clusters.size();
  if (N < 2 || N < MinJumpTableEntries)
    return;

  // Accumulated number of cases in each cluster and those prior to it.
  SmallVector<unsigned, 8> TotalCases(N);
  for (unsigned i = 0; i < N; ++i) {
    const APInt &Hi = Clusters[i].High->getValue();
    const APInt &Lo = Clusters[i].Low->getValue();
    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
    if (i != 0)
      TotalCases[i] += TotalCases[i - 1];
  }

  uint64_t Range = getJumpTableRange(Clusters, 0, N - 1);
  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
  assert(NumCases < UINT64_MAX / 100);
  assert(Range >= NumCases);

  // Cheap case: the whole range may be suitable for a jump table.
  if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
    CaseCluster JTCluster;
    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
      Clusters[0] = JTCluster;
      Clusters.resize(1);
      return;
    }
  }

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // Split Clusters into minimum number of dense partitions. The algorithm uses
  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
  // for the Case Statement'" (1994), but builds the MinPartitions array in
  // reverse order to make it easier to reconstruct the partitions in ascending
  // order. In the choice between two optimal partitionings, it picks the one
  // which yields more jump tables.
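  //
  // Concretely, scanning right to left, each i takes the better of keeping
  // Clusters[i] in a partition of its own or starting a jump-table partition
  // Clusters[i..j]:
  //   MinPartitions[i] = min(MinPartitions[i + 1] + 1,
  //                          min over suitable j > i of 1 + MinPartitions[j + 1])
  // where MinPartitions[N] is treated as 0.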

  // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);
  // PartitionsScore[i] is used to break ties when choosing between two
  // partitionings resulting in the same number of partitions.
  SmallVector<unsigned, 8> PartitionsScore(N);
  // For PartitionsScore, a small number of comparisons is considered as good as
  // a jump table and a single comparison is considered better than a jump
  // table.
  enum PartitionScores : unsigned {
    NoTable = 0,
    Table = 1,
    FewCases = 1,
    SingleCase = 2
  };
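  // E.g., a partition holding a single cluster scores SingleCase (2), beating
  // a jump table (Table = 1), while a handful of clusters scores FewCases (1),
  // tying with one.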

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  PartitionsScore[N - 1] = PartitionScores::SingleCase;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; i--) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;
    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;

    // Search for a solution that results in fewer partitions.
    for (int64_t j = N - 1; j > i; j--) {
      // Try building a partition from Clusters[i..j].
      Range = getJumpTableRange(Clusters, i, j);
      NumCases = getJumpTableNumCases(TotalCases, i, j);
      assert(NumCases < UINT64_MAX / 100);
      assert(Range >= NumCases);

      if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
        int64_t NumEntries = j - i + 1;

        if (NumEntries == 1)
          Score += PartitionScores::SingleCase;
        else if (NumEntries <= SmallNumberOfEntries)
          Score += PartitionScores::FewCases;
        else if (NumEntries >= MinJumpTableEntries)
          Score += PartitionScores::Table;

        // If this leads to fewer partitions, or to the same number of
        // partitions with better score, it is a better partitioning.
        if (NumPartitions < MinPartitions[i] ||
            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
          MinPartitions[i] = NumPartitions;
          LastElement[i] = j;
          PartitionsScore[i] = Score;
        }
      }
    }
  }

  // Iterate over the partitions, replacing some with jump tables in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(Last >= First);
    assert(DstIndex <= First);
    unsigned NumClusters = Last - First + 1;

    CaseCluster JTCluster;
    if (NumClusters >= MinJumpTableEntries &&
        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
      Clusters[DstIndex++] = JTCluster;
    } else {
      for (unsigned I = First; I <= Last; ++I)
        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
                                              unsigned First, unsigned Last,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              CaseCluster &JTCluster) {
  assert(First <= Last);

  auto Prob = BranchProbability::getZero();
  unsigned NumCmps = 0;
  std::vector<MachineBasicBlock*> Table;
  DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;

  // Initialize probabilities in JTProbs.
  for (unsigned I = First; I <= Last; ++I)
    JTProbs[Clusters[I].MBB] = BranchProbability::getZero();

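  // Build the table densely from Clusters[First].Low to Clusters[Last].High,
  // routing values that fall in gaps between clusters to DefaultMBB. E.g.,
  // clusters {1..2} -> A and {5..5} -> B yield
  // Table = [A, A, Default, Default, B] for the index range [1, 5].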
  for (unsigned I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Prob += Clusters[I].Prob;
    const APInt &Low = Clusters[I].Low->getValue();
    const APInt &High = Clusters[I].High->getValue();
    NumCmps += (Low == High) ? 1 : 2;
    if (I != First) {
      // Fill the gap between this and the previous cluster.
      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
      assert(PreviousHigh.slt(Low));
      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
      for (uint64_t J = 0; J < Gap; J++)
        Table.push_back(DefaultMBB);
    }
    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
    for (uint64_t J = 0; J < ClusterSize; ++J)
      Table.push_back(Clusters[I].MBB);
    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
  }

  unsigned NumDests = JTProbs.size();
  if (TLI->isSuitableForBitTests(NumDests, NumCmps,
                                 Clusters[First].Low->getValue(),
                                 Clusters[Last].High->getValue(), *DL)) {
    // Clusters[First..Last] should be lowered as bit tests instead.
    return false;
  }

  // Create the MBB that will load from and jump through the table.
  // Note: We create it here, but it's not inserted into the function yet.
  MachineFunction *CurMF = FuncInfo.MF;
  MachineBasicBlock *JumpTableMBB =
      CurMF->CreateMachineBasicBlock(SI->getParent());

  // Add successors. Note: use table order for determinism.
  SmallPtrSet<MachineBasicBlock *, 8> Done;
  for (MachineBasicBlock *Succ : Table) {
    if (Done.count(Succ))
      continue;
    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
    Done.insert(Succ);
  }
  JumpTableMBB->normalizeSuccProbs();

  unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
                     ->createJumpTableIndex(Table);

  // Set up the jump table info.
  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
  JumpTableHeader JTH(Clusters[First].Low->getValue(),
                      Clusters[Last].High->getValue(), SI->getCondition(),
                      nullptr, false);
  JTCases.emplace_back(std::move(JTH), std::move(JT));

  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                     JTCases.size() - 1, Prob);
  return true;
}

void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
                                                   const SwitchInst *SI) {
  // Partition Clusters into as few subsets as possible, where each subset has a
  // range that fits in a machine word and has <= 3 unique destinations.
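  // E.g., on a 64-bit target, clusters for cases {0, 10, 20, 63} with at most
  // three destinations span a 64-value range and can become one bit-test block.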

#ifndef NDEBUG
  // Clusters must be sorted and contain Range or JumpTable clusters.
  assert(!Clusters.empty());
  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
  for (const CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  for (unsigned i = 1; i < Clusters.size(); ++i)
    assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // If the target does not have a legal shift-left operation, do not emit bit
  // tests at all.
  EVT PTy = TLI->getPointerTy(*DL);
  if (!TLI->isOperationLegal(ISD::SHL, PTy))
    return;

  int BitWidth = PTy.getSizeInBits();
  const int64_t N = Clusters.size();

  // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);

  // FIXME: This might not be the best algorithm for finding bit test clusters.

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; --i) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;

    // Search for a solution that results in fewer partitions.
    // Note: the search is limited by BitWidth, reducing time complexity.
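    // Any wider window cannot fit in a word: each cluster covers at least one
    // distinct value, so Clusters[i..j] spans more than BitWidth values once
    // j - i + 1 > BitWidth.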
    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
      // Try building a partition from Clusters[i..j].

      // Check the range.
      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
                                Clusters[j].High->getValue(), *DL))
        continue;

      // Check the number of destinations and cluster types.
      // FIXME: This works, but doesn't seem very efficient.
      bool RangesOnly = true;
      BitVector Dests(FuncInfo.MF->getNumBlockIDs());
      for (int64_t k = i; k <= j; k++) {
        if (Clusters[k].Kind != CC_Range) {
          RangesOnly = false;
          break;
        }
        Dests.set(Clusters[k].MBB->getNumber());
      }
      if (!RangesOnly || Dests.count() > 3)
        break;

      // Check if it's a better partition.
      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
      if (NumPartitions < MinPartitions[i]) {
        // Found a better partition.
        MinPartitions[i] = NumPartitions;
        LastElement[i] = j;
      }
    }
  }

  // Iterate over the partitions, replacing with bit-test clusters in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(First <= Last);
    assert(DstIndex <= First);

    CaseCluster BitTestCluster;
    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
      Clusters[DstIndex++] = BitTestCluster;
    } else {
      size_t NumClusters = Last - First + 1;
      std::memmove(&Clusters[DstIndex], &Clusters[First],
                   sizeof(Clusters[0]) * NumClusters);
      DstIndex += NumClusters;
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
                                             unsigned First, unsigned Last,
                                             const SwitchInst *SI,
                                             CaseCluster &BTCluster) {
  assert(First <= Last);
  if (First == Last)
    return false;

  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  unsigned NumCmps = 0;
  for (int64_t I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Dests.set(Clusters[I].MBB->getNumber());
    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  }
  unsigned NumDests = Dests.count();

  APInt Low = Clusters[First].Low->getValue();
  APInt High = Clusters[Last].High->getValue();
  assert(Low.slt(High));

  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
    return false;

  APInt LowBound;
  APInt CmpRange;

  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
         "Case range must fit in bit mask!");

  // Check if the clusters cover a contiguous range such that no value in the
  // range will jump to the default statement.
  bool ContiguousRange = true;
  for (int64_t I = First + 1; I <= Last; ++I) {
    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
      ContiguousRange = false;
      break;
    }
  }

  if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
    // Optimize the case where all the case values fit in a word without having
    // to subtract minValue: the subtraction can be elided entirely.
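    // E.g., for cases {1, 5, 20} on a 64-bit target, LowBound becomes 0 and
    // CmpRange becomes 20, so the raw case values serve as shift amounts.
    // ContiguousRange is cleared because values below the original Low now
    // also pass the range check and must still reach the default block.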
    LowBound = APInt::getNullValue(Low.getBitWidth());
    CmpRange = High;
    ContiguousRange = false;
  } else {
    LowBound = Low;
    CmpRange = High - Low;
  }

  CaseBitsVector CBV;
  auto TotalProb = BranchProbability::getZero();
  for (unsigned i = First; i <= Last; ++i) {
    // Find the CaseBits for this destination.
    unsigned j;
    for (j = 0; j < CBV.size(); ++j)
      if (CBV[j].BB == Clusters[i].MBB)
        break;
    if (j == CBV.size())
      CBV.push_back(
          CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
    CaseBits *CB = &CBV[j];

    // Update Mask, Bits and ExtraProb.
    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
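    // E.g., Lo = 2 and Hi = 4 give (-1ULL >> 61) << 2 == 0b11100, setting
    // bits 2 through 4 of the mask.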
    CB->Bits += Hi - Lo + 1;
    CB->ExtraProb += Clusters[i].Prob;
    TotalProb += Clusters[i].Prob;
  }

  BitTestInfo BTI;
  llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
    // Sort by probability first, number of bits second, bit mask third.
    if (a.ExtraProb != b.ExtraProb)
      return a.ExtraProb > b.ExtraProb;
    if (a.Bits != b.Bits)
      return a.Bits > b.Bits;
    return a.Mask < b.Mask;
  });

  for (auto &CB : CBV) {
    MachineBasicBlock *BitTestBB =
        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
    BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
  }
  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                            SI->getCondition(), -1U, MVT::Other, false,
                            ContiguousRange, nullptr, nullptr, std::move(BTI),
                            TotalProb);

  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                    BitTestCases.size() - 1, TotalProb);
  return true;
}

void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
#ifndef NDEBUG
  for (const CaseCluster &CC : Clusters)
    assert(CC.Low == CC.High && "Input clusters must be single-case");
#endif

  llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
    return a.Low->getValue().slt(b.Low->getValue());
  });

  // Merge adjacent clusters with the same destination.
  const unsigned N = Clusters.size();
  unsigned DstIndex = 0;
  for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
    CaseCluster &CC = Clusters[SrcIndex];
    const ConstantInt *CaseVal = CC.Low;
    MachineBasicBlock *Succ = CC.MBB;

    if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
        (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
      // If this case has the same successor and is a neighbour, merge it into
      // the previous cluster.
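      // E.g., single cases 1 -> BB1, 2 -> BB1, 3 -> BB1 collapse into one
      // cluster [1, 3] -> BB1 with their probabilities summed.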
      Clusters[DstIndex - 1].High = CaseVal;
      Clusters[DstIndex - 1].Prob += CC.Prob;
    } else {
      std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
                   sizeof(Clusters[SrcIndex]));
    }
  }
  Clusters.resize(DstIndex);
}