//===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains switch inst lowering optimizations and utilities for
// codegen, so that it can be used for both SelectionDAG and GlobalISel.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/SwitchLoweringUtils.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace SwitchCG;

uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
                                     unsigned First, unsigned Last) {
  assert(Last >= First);
  const APInt &LowCase = Clusters[First].Low->getValue();
  const APInt &HighCase = Clusters[Last].High->getValue();
  assert(LowCase.getBitWidth() == HighCase.getBitWidth());

  // FIXME: A range of consecutive cases has 100% density, but only requires one
  // comparison to lower. We should discriminate against such consecutive ranges
  // in jump tables.
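  // For example, clusters covering case values {1, 3, 10} have
  // Range = 10 - 1 + 1 = 10, no matter how sparse the values in between are.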
  return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
}

uint64_t
SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
                               unsigned First, unsigned Last) {
  assert(Last >= First);
  assert(TotalCases[Last] >= TotalCases[First]);
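  // TotalCases is an inclusive prefix sum, so the number of individual case
  // values in Clusters[First..Last] is TotalCases[Last] minus everything
  // accumulated before First. E.g. with TotalCases = {2, 5, 6}, clusters
  // [1..2] contain 6 - 2 = 4 case values.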
  uint64_t NumCases =
      TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
  return NumCases;
}

void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB) {
#ifndef NDEBUG
  // Clusters must be non-empty, sorted, and only contain Range clusters.
  assert(!Clusters.empty());
  for (CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range);
  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
    return;

  const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;

  // Bail if not enough cases.
  const int64_t N = Clusters.size();
  if (N < 2 || N < (int64_t)MinJumpTableEntries)
    return;

  // Accumulated number of cases in each cluster and those prior to it.
  SmallVector<unsigned, 8> TotalCases(N);
  for (unsigned i = 0; i < N; ++i) {
    const APInt &Hi = Clusters[i].High->getValue();
    const APInt &Lo = Clusters[i].Low->getValue();
    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
    if (i != 0)
      TotalCases[i] += TotalCases[i - 1];
  }

  uint64_t Range = getJumpTableRange(Clusters, 0, N - 1);
  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
  assert(NumCases < UINT64_MAX / 100);
  assert(Range >= NumCases);

  // Cheap case: the whole range may be suitable for jump table.
  if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
    CaseCluster JTCluster;
    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
      Clusters[0] = JTCluster;
      Clusters.resize(1);
      return;
    }
  }

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // Split Clusters into minimum number of dense partitions. The algorithm uses
  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
  // for the Case Statement'" (1994), but builds the MinPartitions array in
  // reverse order to make it easier to reconstruct the partitions in ascending
  // order. In the choice between two optimal partitionings, it picks the one
  // which yields more jump tables.
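  //
  // For example, for clusters with case values {0, 1, 2, 3, 100}, splitting
  // into {0..3} (a dense jump table candidate) plus {100} on its own is
  // preferable to a single sparse table spanning 0..100.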

  // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);
  // PartitionsScore[i] is used to break ties when choosing between two
  // partitionings resulting in the same number of partitions.
  SmallVector<unsigned, 8> PartitionsScore(N);
  // For PartitionsScore, a small number of comparisons is considered as good as
  // a jump table and a single comparison is considered better than a jump
  // table.
  enum PartitionScores : unsigned {
    NoTable = 0,
    Table = 1,
    FewCases = 1,
    SingleCase = 2
  };
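  // The score only breaks ties between partitionings with the same number of
  // partitions. E.g. two partitions scoring {Table, SingleCase} (1 + 2 = 3)
  // win over two partitions scoring {FewCases, FewCases} (1 + 1 = 2).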

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  PartitionsScore[N - 1] = PartitionScores::SingleCase;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; i--) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;
    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;

    // Search for a solution that results in fewer partitions.
    for (int64_t j = N - 1; j > i; j--) {
      // Try building a partition from Clusters[i..j].
      Range = getJumpTableRange(Clusters, i, j);
      NumCases = getJumpTableNumCases(TotalCases, i, j);
      assert(NumCases < UINT64_MAX / 100);
      assert(Range >= NumCases);

      if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
        int64_t NumEntries = j - i + 1;

        if (NumEntries == 1)
          Score += PartitionScores::SingleCase;
        else if (NumEntries <= SmallNumberOfEntries)
          Score += PartitionScores::FewCases;
        else if (NumEntries >= MinJumpTableEntries)
          Score += PartitionScores::Table;

        // If this leads to fewer partitions, or to the same number of
        // partitions with better score, it is a better partitioning.
        if (NumPartitions < MinPartitions[i] ||
            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
          MinPartitions[i] = NumPartitions;
          LastElement[i] = j;
          PartitionsScore[i] = Score;
        }
      }
    }
  }

  // Iterate over the partitions, replacing some with jump tables in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(Last >= First);
    assert(DstIndex <= First);
    unsigned NumClusters = Last - First + 1;

    CaseCluster JTCluster;
    if (NumClusters >= MinJumpTableEntries &&
        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
      Clusters[DstIndex++] = JTCluster;
    } else {
      for (unsigned I = First; I <= Last; ++I)
        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
                                              unsigned First, unsigned Last,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              CaseCluster &JTCluster) {
  assert(First <= Last);

  auto Prob = BranchProbability::getZero();
  unsigned NumCmps = 0;
  std::vector<MachineBasicBlock*> Table;
  DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;

  // Initialize probabilities in JTProbs.
  for (unsigned I = First; I <= Last; ++I)
    JTProbs[Clusters[I].MBB] = BranchProbability::getZero();

  for (unsigned I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Prob += Clusters[I].Prob;
    const APInt &Low = Clusters[I].Low->getValue();
    const APInt &High = Clusters[I].High->getValue();
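    // A single-value cluster (Low == High) lowers to one equality comparison;
    // a true range needs two comparisons (lower and upper bound).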
    NumCmps += (Low == High) ? 1 : 2;
    if (I != First) {
      // Fill the gap between this and the previous cluster.
      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
      assert(PreviousHigh.slt(Low));
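      // E.g. between clusters [1, 2] and [5, 6], Gap = (5 - 2) - 1 = 2, so the
      // table entries for 3 and 4 are routed to the default destination.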
      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
      for (uint64_t J = 0; J < Gap; J++)
        Table.push_back(DefaultMBB);
    }
    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
    for (uint64_t J = 0; J < ClusterSize; ++J)
      Table.push_back(Clusters[I].MBB);
    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
  }

  unsigned NumDests = JTProbs.size();
  if (TLI->isSuitableForBitTests(NumDests, NumCmps,
                                 Clusters[First].Low->getValue(),
                                 Clusters[Last].High->getValue(), *DL)) {
    // Clusters[First..Last] should be lowered as bit tests instead.
    return false;
  }

  // Create the MBB that will load from and jump through the table.
  // Note: We create it here, but it's not inserted into the function yet.
  MachineFunction *CurMF = FuncInfo.MF;
  MachineBasicBlock *JumpTableMBB =
      CurMF->CreateMachineBasicBlock(SI->getParent());

  // Add successors. Note: use table order for determinism.
  SmallPtrSet<MachineBasicBlock *, 8> Done;
  for (MachineBasicBlock *Succ : Table) {
    if (Done.count(Succ))
      continue;
    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
    Done.insert(Succ);
  }
  JumpTableMBB->normalizeSuccProbs();

  unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
                     ->createJumpTableIndex(Table);

  // Set up the jump table info.
  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
  JumpTableHeader JTH(Clusters[First].Low->getValue(),
                      Clusters[Last].High->getValue(), SI->getCondition(),
                      nullptr, false);
  JTCases.emplace_back(std::move(JTH), std::move(JT));

  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                     JTCases.size() - 1, Prob);
  return true;
}

void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
                                                   const SwitchInst *SI) {
  // Partition Clusters into as few subsets as possible, where each subset has a
  // range that fits in a machine word and has <= 3 unique destinations.
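  // For example, with a 64-bit machine word, clusters covering the case range
  // [10, 73] (64 values in total) and branching to at most three distinct
  // blocks are a candidate for a single bit-test partition.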

#ifndef NDEBUG
  // Clusters must be sorted and contain Range or JumpTable clusters.
  assert(!Clusters.empty());
  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
  for (const CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  for (unsigned i = 1; i < Clusters.size(); ++i)
    assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // If the target does not have legal shift left, do not emit bit tests at all.
  EVT PTy = TLI->getPointerTy(*DL);
  if (!TLI->isOperationLegal(ISD::SHL, PTy))
    return;

  int BitWidth = PTy.getSizeInBits();
  const int64_t N = Clusters.size();

  // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);

  // FIXME: This might not be the best algorithm for finding bit test clusters.

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; --i) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;

    // Search for a solution that results in fewer partitions.
    // Note: the search is limited by BitWidth, reducing time complexity.
    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
      // Try building a partition from Clusters[i..j].

      // Check the range.
      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
                                Clusters[j].High->getValue(), *DL))
        continue;

      // Check the number of destinations and cluster types.
      // FIXME: This works, but doesn't seem very efficient.
      bool RangesOnly = true;
      BitVector Dests(FuncInfo.MF->getNumBlockIDs());
      for (int64_t k = i; k <= j; k++) {
        if (Clusters[k].Kind != CC_Range) {
          RangesOnly = false;
          break;
        }
        Dests.set(Clusters[k].MBB->getNumber());
      }
      if (!RangesOnly || Dests.count() > 3)
        break;

      // Check if it's a better partition.
      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
      if (NumPartitions < MinPartitions[i]) {
        // Found a better partition.
        MinPartitions[i] = NumPartitions;
        LastElement[i] = j;
      }
    }
  }

  // Iterate over the partitions, replacing with bit-test clusters in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(First <= Last);
    assert(DstIndex <= First);

    CaseCluster BitTestCluster;
    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
      Clusters[DstIndex++] = BitTestCluster;
    } else {
      size_t NumClusters = Last - First + 1;
      std::memmove(&Clusters[DstIndex], &Clusters[First],
                   sizeof(Clusters[0]) * NumClusters);
      DstIndex += NumClusters;
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
                                             unsigned First, unsigned Last,
                                             const SwitchInst *SI,
                                             CaseCluster &BTCluster) {
  assert(First <= Last);
  if (First == Last)
    return false;

  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  unsigned NumCmps = 0;
  for (int64_t I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Dests.set(Clusters[I].MBB->getNumber());
    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  }
  unsigned NumDests = Dests.count();

  APInt Low = Clusters[First].Low->getValue();
  APInt High = Clusters[Last].High->getValue();
  assert(Low.slt(High));

  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
    return false;

  APInt LowBound;
  APInt CmpRange;

  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
         "Case range must fit in bit mask!");

  // Check if the clusters cover a contiguous range such that no value in the
  // range will jump to the default statement.
  bool ContiguousRange = true;
  for (int64_t I = First + 1; I <= Last; ++I) {
    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
      ContiguousRange = false;
      break;
    }
  }

  if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
    // Optimize the case where all the case values fit in a word without having
    // to subtract minValue: the subtraction itself can then be optimized away.
    LowBound = APInt::getNullValue(Low.getBitWidth());
    CmpRange = High;
    ContiguousRange = false;
  } else {
    LowBound = Low;
    CmpRange = High - Low;
  }
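  // E.g. for case values {1, 3, 5} on a 64-bit target: Low = 1 is strictly
  // positive and High = 5 < 64, so LowBound = 0, CmpRange = 5, and the bit
  // tests use the switch value directly instead of subtracting Low first.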

  CaseBitsVector CBV;
  auto TotalProb = BranchProbability::getZero();
  for (unsigned i = First; i <= Last; ++i) {
    // Find the CaseBits for this destination.
    unsigned j;
    for (j = 0; j < CBV.size(); ++j)
      if (CBV[j].BB == Clusters[i].MBB)
        break;
    if (j == CBV.size())
      CBV.push_back(
          CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
    CaseBits *CB = &CBV[j];

    // Update Mask, Bits and ExtraProb.
    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
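    // Set bits [Lo, Hi] of the destination's mask. E.g. Lo = 2, Hi = 4 gives
    // (-1ULL >> (63 - 2)) << 2 == 0b111 << 2 == 0b11100.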
    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
    CB->Bits += Hi - Lo + 1;
    CB->ExtraProb += Clusters[i].Prob;
    TotalProb += Clusters[i].Prob;
  }

  BitTestInfo BTI;
  llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
    // Sort by probability first, number of bits second, bit mask third.
    if (a.ExtraProb != b.ExtraProb)
      return a.ExtraProb > b.ExtraProb;
    if (a.Bits != b.Bits)
      return a.Bits > b.Bits;
    return a.Mask < b.Mask;
  });

  for (auto &CB : CBV) {
    MachineBasicBlock *BitTestBB =
        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
    BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
  }
  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                            SI->getCondition(), -1U, MVT::Other, false,
                            ContiguousRange, nullptr, nullptr, std::move(BTI),
                            TotalProb);

  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                    BitTestCases.size() - 1, TotalProb);
  return true;
}

void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
#ifndef NDEBUG
  for (const CaseCluster &CC : Clusters)
    assert(CC.Low == CC.High && "Input clusters must be single-case");
#endif

  llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
    return a.Low->getValue().slt(b.Low->getValue());
  });

  // Merge adjacent clusters with the same destination.
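  // E.g. cases 1, 2 and 3 all branching to the same block are merged into one
  // cluster [1, 3], with their probabilities summed.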
  const unsigned N = Clusters.size();
  unsigned DstIndex = 0;
  for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
    CaseCluster &CC = Clusters[SrcIndex];
    const ConstantInt *CaseVal = CC.Low;
    MachineBasicBlock *Succ = CC.MBB;

    if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
        (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
      // If this case has the same successor and is a neighbour, merge it into
      // the previous cluster.
      Clusters[DstIndex - 1].High = CaseVal;
      Clusters[DstIndex - 1].Prob += CC.Prob;
    } else {
      std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
                   sizeof(Clusters[SrcIndex]));
    }
  }
  Clusters.resize(DstIndex);
}