1 //===- ReservoirSampler.cpp - Tests for the ReservoirSampler --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/FuzzMutate/Random.h" 10 #include "gtest/gtest.h" 11 #include <random> 12 13 using namespace llvm; 14 15 TEST(ReservoirSamplerTest, OneItem) { 16 std::mt19937 Rand; 17 auto Sampler = makeSampler(Rand, 7, 1); 18 ASSERT_FALSE(Sampler.isEmpty()); 19 ASSERT_EQ(7, Sampler.getSelection()); 20 } 21 22 TEST(ReservoirSamplerTest, NoWeight) { 23 std::mt19937 Rand; 24 auto Sampler = makeSampler(Rand, 7, 0); 25 ASSERT_TRUE(Sampler.isEmpty()); 26 } 27 28 TEST(ReservoirSamplerTest, Uniform) { 29 std::mt19937 Rand; 30 31 // Run three chi-squared tests to check that the distribution is reasonably 32 // uniform. 33 std::vector<int> Items = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; 34 35 int Failures = 0; 36 for (int Run = 0; Run < 3; ++Run) { 37 std::vector<int> Counts(Items.size(), 0); 38 39 // We need $np_s > 5$ at minimum, but we're better off going a couple of 40 // orders of magnitude larger. 41 int N = Items.size() * 5 * 100; 42 for (int I = 0; I < N; ++I) { 43 auto Sampler = makeSampler(Rand, Items); 44 Counts[Sampler.getSelection()] += 1; 45 } 46 47 // Knuth. TAOCP Vol. 2, 3.3.1 (8): 48 // $V = \frac{1}{n} \sum_{s=1}^{k} \left(\frac{Y_s^2}{p_s}\right) - n$ 49 double Ps = 1.0 / Items.size(); 50 double Sum = 0.0; 51 for (int Ys : Counts) 52 Sum += Ys * Ys / Ps; 53 double V = (Sum / N) - N; 54 55 assert(Items.size() == 10 && "Our chi-squared values assume 10 items"); 56 // Since we have 10 items, there are 9 degrees of freedom and the table of 57 // chi-squared values is as follows: 58 // 59 // | p=1% | 5% | 25% | 50% | 75% | 95% | 99% | 60 // v=9 | 2.088 | 3.325 | 5.899 | 8.343 | 11.39 | 16.92 | 21.67 | 61 // 62 // Check that we're in the likely range of results. 63 //if (V < 2.088 || V > 21.67) 64 if (V < 2.088 || V > 21.67) 65 ++Failures; 66 } 67 EXPECT_LT(Failures, 3) << "Non-uniform distribution?"; 68 } 69