1 //===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains switch inst lowering optimizations and utilities for
10 // codegen, so that it can be used for both SelectionDAG and GlobalISel.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/SmallSet.h"
15 #include "llvm/CodeGen/MachineJumpTableInfo.h"
16 #include "llvm/CodeGen/SwitchLoweringUtils.h"
17 
18 using namespace llvm;
19 using namespace SwitchCG;
20 
namespace {
// Collection of partition stats, made up of, for a given cluster,
// the range of the cases, their number and the number of unique targets.
// Kept in an anonymous namespace to give the type internal linkage, so an
// identically named type in another translation unit cannot cause an ODR
// violation.
struct PartitionStats {
  uint64_t Range, Cases, Targets;
};
} // end anonymous namespace
26 
27 static PartitionStats getJumpTableStats(const CaseClusterVector &Clusters,
28                                         unsigned First, unsigned Last,
29                                         bool HasReachableDefault) {
30   assert(Last >= First && "Invalid order of clusters");
31 
32   SmallSet<const MachineBasicBlock *, 8> Targets;
33   PartitionStats Stats;
34 
35   Stats.Cases = 0;
36   for (unsigned i = First; i <= Last; ++i) {
37     const APInt &Hi = Clusters[i].High->getValue(),
38                 &Lo = Clusters[i].Low->getValue();
39     Stats.Cases += (Hi - Lo).getLimitedValue() + 1;
40 
41     Targets.insert(Clusters[i].MBB);
42   }
43   assert(Stats.Cases < UINT64_MAX / 100 && "Too many cases");
44 
45   const APInt &Hi = Clusters[Last].High->getValue(),
46               &Lo = Clusters[First].Low->getValue();
47   assert(Hi.getBitWidth() == Lo.getBitWidth());
48   Stats.Range = (Hi - Lo).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
49   assert(Stats.Range >= Stats.Cases && "Invalid range or number of cases");
50 
51   Stats.Targets =
52       Targets.size() + (HasReachableDefault && Stats.Range > Stats.Cases);
53 
54   return Stats;
55 }
56 
/// Partition Clusters into jump tables where profitable, rewriting the vector
/// in place: each chosen partition is replaced by a single CC_JumpTable
/// cluster, and the surviving CC_Range clusters are compacted to the front.
void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB) {
#ifndef NDEBUG
  // Clusters must be non-empty, sorted, and only contain Range clusters.
  assert(!Clusters.empty());
  for (CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range);
  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  assert(TLI && "TLI not set!");
  // Respect targets that disallow jump tables for this function.
  if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
    return;

  const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;

  // Bail if not enough cases.
  const int64_t N = Clusters.size();
  if (N < 2 || N < MinJumpTableEntries)
    return;

  // If the default block begins with 'unreachable', no out-of-range value can
  // branch there, so holes in the case range do not add an extra target.
  const bool HasReachableDefault =
      !isa<UnreachableInst>(DefaultMBB->getBasicBlock()->getFirstNonPHIOrDbg());
  PartitionStats Stats =
      getJumpTableStats(Clusters, 0, N - 1, HasReachableDefault);

  // Cheap case: the whole range may be suitable for jump table.
  if (TLI->isSuitableForJumpTable(SI, Stats.Cases, Stats.Targets, Stats.Range)) {
    CaseCluster JTCluster;
    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
      Clusters[0] = JTCluster;
      Clusters.resize(1);
      return;
    }
  }

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // Split Clusters into minimum number of dense partitions. The algorithm uses
  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
  // for the Case Statement'" (1994), but builds the MinPartitions array in
  // reverse order to make it easier to reconstruct the partitions in ascending
  // order. In the choice between two optimal partitionings, it picks the one
  // which yields more jump tables.

  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);
  // For PartitionsScore, a small number of comparisons is considered as good as
  // a jump table and a single comparison is considered better than a jump
  // table.
  enum PartitionScores : unsigned {
    NoTable = 0,
    Table = 1,
    FewCases = 1,
    SingleCase = 2
  };
  // PartitionsScore[i] is used to break ties when choosing between two
  // partitionings resulting in the same number of partitions.
  SmallVector<unsigned, 8> PartitionsScore(N);
  // PartitionsStats[j] is the stats for the partition Clusters[i..j].
  SmallVector<PartitionStats, 8> PartitionsStats(N);

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  PartitionsScore[N - 1] = PartitionScores::SingleCase;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; i--) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;
    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;
    // Precompute the stats of every candidate partition Clusters[i..j].
    // NOTE(review): getJumpTableStats rescans [i..j] from scratch, so this
    // inner loop is O(N^2) per i (O(N^3) overall); an incremental update
    // would reduce that — confirm N is small enough in practice.
    for (int64_t j = i + 1; j < N; j++)
      PartitionsStats[j] =
          getJumpTableStats(Clusters, i, j, HasReachableDefault);

    // Search for a solution that results in fewer partitions.
    for (int64_t j = N - 1; j > i; j--) {
      // Try building a partition from Clusters[i..j].
      if (TLI->isSuitableForJumpTable(SI, PartitionsStats[j].Cases,
                                      PartitionsStats[j].Targets,
                                      PartitionsStats[j].Range)) {
        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
        int64_t NumEntries = j - i + 1;

        // Score the partition: a single comparison beats a table, and a
        // handful of comparisons is considered as good as a table.
        if (NumEntries == 1)
          Score += PartitionScores::SingleCase;
        else if (NumEntries <= SmallNumberOfEntries)
          Score += PartitionScores::FewCases;
        else if (NumEntries >= MinJumpTableEntries)
          Score += PartitionScores::Table;

        // If this leads to fewer partitions, or to the same number of
        // partitions with better score, it is a better partitioning.
        if (NumPartitions < MinPartitions[i] ||
            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
          MinPartitions[i] = NumPartitions;
          LastElement[i] = j;
          PartitionsScore[i] = Score;
        }
      }
    }
  }

  // Iterate over the partitions, replacing some with jump tables in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(Last >= First);
    assert(DstIndex <= First);
    unsigned NumClusters = Last - First + 1;

    CaseCluster JTCluster;
    if (NumClusters >= MinJumpTableEntries &&
        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
      Clusters[DstIndex++] = JTCluster;
    } else {
      // Keep the clusters as-is, compacting them toward the front. Source and
      // destination may overlap (DstIndex <= I), hence memmove.
      for (unsigned I = First; I <= Last; ++I)
        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
    }
  }
  Clusters.resize(DstIndex);
}
190 
/// Try to lower Clusters[First..Last] as a single jump table. On success,
/// records the table in JTCases, sets JTCluster to the replacement cluster,
/// and returns true. Returns false if the partition would be better lowered
/// as bit tests.
bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
                                              unsigned First, unsigned Last,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              CaseCluster &JTCluster) {
  assert(First <= Last);

  // Accumulated probability of entering the jump table at all.
  auto Prob = BranchProbability::getZero();
  unsigned NumCmps = 0;
  std::vector<MachineBasicBlock*> Table;
  DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;

  // Initialize probabilities in JTProbs.
  for (unsigned I = First; I <= Last; ++I)
    JTProbs[Clusters[I].MBB] = BranchProbability::getZero();

  for (unsigned I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Prob += Clusters[I].Prob;
    const APInt &Low = Clusters[I].Low->getValue();
    const APInt &High = Clusters[I].High->getValue();
    // A single-value cluster would cost one comparison, a range two.
    NumCmps += (Low == High) ? 1 : 2;
    if (I != First) {
      // Fill the gap between this and the previous cluster.
      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
      assert(PreviousHigh.slt(Low));
      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
      for (uint64_t J = 0; J < Gap; J++)
        Table.push_back(DefaultMBB);
    }
    // One table entry per case value covered by the cluster.
    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
    for (uint64_t J = 0; J < ClusterSize; ++J)
      Table.push_back(Clusters[I].MBB);
    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
  }

  unsigned NumDests = JTProbs.size();
  if (TLI->isSuitableForBitTests(NumDests, NumCmps,
                                 Clusters[First].Low->getValue(),
                                 Clusters[Last].High->getValue(), *DL)) {
    // Clusters[First..Last] should be lowered as bit tests instead.
    return false;
  }

  // Create the MBB that will load from and jump through the table.
  // Note: We create it here, but it's not inserted into the function yet.
  MachineFunction *CurMF = FuncInfo.MF;
  MachineBasicBlock *JumpTableMBB =
      CurMF->CreateMachineBasicBlock(SI->getParent());

  // Add successors. Note: use table order for determinism.
  SmallPtrSet<MachineBasicBlock *, 8> Done;
  for (MachineBasicBlock *Succ : Table) {
    if (Done.count(Succ))
      continue;
    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
    Done.insert(Succ);
  }
  JumpTableMBB->normalizeSuccProbs();

  unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
                     ->createJumpTableIndex(Table);

  // Set up the jump table info.
  // NOTE(review): the -1U register and null block/header arguments appear to
  // be placeholders filled in later during lowering — confirm against the
  // JumpTable/JumpTableHeader constructors.
  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
  JumpTableHeader JTH(Clusters[First].Low->getValue(),
                      Clusters[Last].High->getValue(), SI->getCondition(),
                      nullptr, false);
  JTCases.emplace_back(std::move(JTH), std::move(JT));

  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                     JTCases.size() - 1, Prob);
  return true;
}
265 
/// Partition Clusters into bit-test clusters where profitable, rewriting the
/// vector in place (analogous to findJumpTables, but for bit tests).
void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
                                                   const SwitchInst *SI) {
  // Partition Clusters into as few subsets as possible, where each subset has a
  // range that fits in a machine word and has <= 3 unique destinations.

#ifndef NDEBUG
  // Clusters must be sorted and contain Range or JumpTable clusters.
  assert(!Clusters.empty());
  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
  for (const CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  for (unsigned i = 1; i < Clusters.size(); ++i)
    assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // If target does not have legal shift left, do not emit bit tests at all.
  EVT PTy = TLI->getPointerTy(*DL);
  if (!TLI->isOperationLegal(ISD::SHL, PTy))
    return;

  int BitWidth = PTy.getSizeInBits();
  const int64_t N = Clusters.size();

  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);

  // FIXME: This might not be the best algorithm for finding bit test clusters.

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; --i) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;

    // Search for a solution that results in fewer partitions.
    // Note: the search is limited by BitWidth, reducing time complexity.
    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
      // Try building a partition from Clusters[i..j].

      // Check the range.
      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
                                Clusters[j].High->getValue(), *DL))
        continue;

      // Check nbr of destinations and cluster types.
      // FIXME: This works, but doesn't seem very efficient.
      bool RangesOnly = true;
      BitVector Dests(FuncInfo.MF->getNumBlockIDs());
      for (int64_t k = i; k <= j; k++) {
        if (Clusters[k].Kind != CC_Range) {
          RangesOnly = false;
          break;
        }
        Dests.set(Clusters[k].MBB->getNumber());
      }
      // NOTE(review): breaking here also skips smaller sub-partitions
      // [i..j'] (j' < j) that might still have <= 3 destinations — confirm
      // this early-out is the intended heuristic.
      if (!RangesOnly || Dests.count() > 3)
        break;

      // Check if it's a better partition.
      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
      if (NumPartitions < MinPartitions[i]) {
        // Found a better partition.
        MinPartitions[i] = NumPartitions;
        LastElement[i] = j;
      }
    }
  }

  // Iterate over the partitions, replacing with bit-test clusters in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(First <= Last);
    assert(DstIndex <= First);

    CaseCluster BitTestCluster;
    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
      Clusters[DstIndex++] = BitTestCluster;
    } else {
      // Keep the clusters, compacting them forward. Source and destination
      // ranges may overlap (DstIndex <= First), hence memmove.
      size_t NumClusters = Last - First + 1;
      std::memmove(&Clusters[DstIndex], &Clusters[First],
                   sizeof(Clusters[0]) * NumClusters);
      DstIndex += NumClusters;
    }
  }
  Clusters.resize(DstIndex);
}
364 
/// Try to lower Clusters[First..Last] as a sequence of bit tests. On success,
/// records the tests in BitTestCases, sets BTCluster to the replacement
/// cluster, and returns true.
bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
                                             unsigned First, unsigned Last,
                                             const SwitchInst *SI,
                                             CaseCluster &BTCluster) {
  assert(First <= Last);
  // A single cluster gains nothing from a bit test.
  if (First == Last)
    return false;

  // Count the distinct destinations and the comparisons a plain lowering
  // would need: one per single-value cluster, two per range.
  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  unsigned NumCmps = 0;
  for (int64_t I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Dests.set(Clusters[I].MBB->getNumber());
    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  }
  unsigned NumDests = Dests.count();

  APInt Low = Clusters[First].Low->getValue();
  APInt High = Clusters[Last].High->getValue();
  assert(Low.slt(High));

  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
    return false;

  APInt LowBound;
  APInt CmpRange;

  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
         "Case range must fit in bit mask!");

  // Check if the clusters cover a contiguous range such that no value in the
  // range will jump to the default statement.
  bool ContiguousRange = true;
  for (int64_t I = First + 1; I <= Last; ++I) {
    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
      ContiguousRange = false;
      break;
    }
  }

  if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
    // Optimize the case where all the case values fit in a word without having
    // to subtract minValue. In this case, we can optimize away the subtraction.
    LowBound = APInt::getNullValue(Low.getBitWidth());
    CmpRange = High;
    // Values in [0, Low) are now inside the compared range but match no case,
    // so the range is no longer contiguous.
    ContiguousRange = false;
  } else {
    LowBound = Low;
    CmpRange = High - Low;
  }

  // Build one CaseBits entry per destination, accumulating the bit mask of
  // matching values (relative to LowBound) and the branch probability.
  CaseBitsVector CBV;
  auto TotalProb = BranchProbability::getZero();
  for (unsigned i = First; i <= Last; ++i) {
    // Find the CaseBits for this destination.
    unsigned j;
    for (j = 0; j < CBV.size(); ++j)
      if (CBV[j].BB == Clusters[i].MBB)
        break;
    if (j == CBV.size())
      CBV.push_back(
          CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
    CaseBits *CB = &CBV[j];

    // Update Mask, Bits and ExtraProb.
    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
    // Set bits [Lo, Hi] of this destination's mask.
    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
    CB->Bits += Hi - Lo + 1;
    CB->ExtraProb += Clusters[i].Prob;
    TotalProb += Clusters[i].Prob;
  }

  BitTestInfo BTI;
  llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
    // Sort by probability first, number of bits second, bit mask third.
    if (a.ExtraProb != b.ExtraProb)
      return a.ExtraProb > b.ExtraProb;
    if (a.Bits != b.Bits)
      return a.Bits > b.Bits;
    return a.Mask < b.Mask;
  });

  for (auto &CB : CBV) {
    // Each destination gets its own block that performs the mask check.
    MachineBasicBlock *BitTestBB =
        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
    BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
  }
  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                            SI->getCondition(), -1U, MVT::Other, false,
                            ContiguousRange, nullptr, nullptr, std::move(BTI),
                            TotalProb);

  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                    BitTestCases.size() - 1, TotalProb);
  return true;
}
464 
465 void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
466 #ifndef NDEBUG
467   for (const CaseCluster &CC : Clusters)
468     assert(CC.Low == CC.High && "Input clusters must be single-case");
469 #endif
470 
471   llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
472     return a.Low->getValue().slt(b.Low->getValue());
473   });
474 
475   // Merge adjacent clusters with the same destination.
476   const unsigned N = Clusters.size();
477   unsigned DstIndex = 0;
478   for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
479     CaseCluster &CC = Clusters[SrcIndex];
480     const ConstantInt *CaseVal = CC.Low;
481     MachineBasicBlock *Succ = CC.MBB;
482 
483     if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
484         (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
485       // If this case has the same successor and is a neighbour, merge it into
486       // the previous cluster.
487       Clusters[DstIndex - 1].High = CaseVal;
488       Clusters[DstIndex - 1].Prob += CC.Prob;
489     } else {
490       std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
491                    sizeof(Clusters[SrcIndex]));
492     }
493   }
494   Clusters.resize(DstIndex);
495 }
496