#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import os
import random
import sys

from block_cache_pysim import (
    ARCCache,
    CacheEntry,
    GDSizeCache,
    HashTable,
    HyperbolicPolicy,
    LFUPolicy,
    LinUCBCache,
    LRUCache,
    LRUPolicy,
    MRUPolicy,
    OPTCache,
    OPTCacheEntry,
    ThompsonSamplingCache,
    TraceCache,
    TraceRecord,
    create_cache,
    kMicrosInSecond,
    kSampleSize,
    run,
)


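# Verify the custom HashTable implementation used by the simulated caches.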
def test_hash_table():
    print("Test hash table")
    table = HashTable()
    data_size = 10000
    for i in range(data_size):
        table.insert("k{}".format(i), i, "v{}".format(i))
    for i in range(data_size):
        assert table.lookup("k{}".format(i), i) is not None
    for i in range(data_size):
        table.delete("k{}".format(i), i)
    for i in range(data_size):
        assert table.lookup("k{}".format(i), i) is None

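    # Randomized insert/delete workload, cross-checked against a plain dict
    # (truth_map) after every operation.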
    truth_map = {}
    n = 1000000
    records = 100
    for i in range(n):
        key_id = random.randint(0, records)
        v = random.randint(0, records)
        key = "k{}".format(key_id)
        value = CacheEntry(v, v, v, v, v, v, v)
        action = random.randint(0, 10)
        assert len(truth_map) == table.elements, "{} {} {}".format(
            len(truth_map), table.elements, i
        )
        if action <= 8:
            if key in truth_map:
                assert table.lookup(key, key_id) is not None
                assert truth_map[key].value_size == table.lookup(key, key_id).value_size
            else:
                assert table.lookup(key, key_id) is None
            table.insert(key, key_id, value)
            truth_map[key] = value
        else:
            deleted = table.delete(key, key_id)
            if deleted:
                assert key in truth_map
            if key in truth_map:
                del truth_map[key]

    # Check all keys are unique in the sample set.
    for _i in range(10):
        samples = table.random_sample(kSampleSize)
        unique_keys = {}
        for sample in samples:
            unique_keys[sample.key] = True
        assert len(samples) == len(unique_keys)

    assert len(table) == len(truth_map)
    for key in truth_map:
        assert table.lookup(key, int(key[1:])) is not None
        assert truth_map[key].value_size == table.lookup(key, int(key[1:])).value_size
    print("Test hash table: Success")


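# expected_value layout: [used_size, num_accesses, num_misses,
# cached block ids, cached get (row) key ids]. Block keys are looked up as
# "b<id>" and get keys as "g0-<id>".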
def assert_metrics(cache, expected_value, expected_value_size=1, custom_hashtable=True):
    assert cache.used_size == expected_value[0], "Expected {}, Actual {}".format(
        expected_value[0], cache.used_size
    )
    assert (
        cache.miss_ratio_stats.num_accesses == expected_value[1]
    ), "Expected {}, Actual {}".format(
        expected_value[1], cache.miss_ratio_stats.num_accesses
    )
    assert (
        cache.miss_ratio_stats.num_misses == expected_value[2]
    ), "Expected {}, Actual {}".format(
        expected_value[2], cache.miss_ratio_stats.num_misses
    )
    assert len(cache.table) == len(expected_value[3]) + len(
        expected_value[4]
    ), "Expected {}, Actual {}".format(
        len(expected_value[3]) + len(expected_value[4]), len(cache.table)
    )
    for expected_k in expected_value[3]:
        if custom_hashtable:
            val = cache.table.lookup("b{}".format(expected_k), expected_k)
        else:
            val = cache.table["b{}".format(expected_k)]
        assert val is not None, "Expected {} Actual: Not Exist {}, Table: {}".format(
            expected_k, expected_value, cache.table
        )
        assert val.value_size == expected_value_size
    for expected_k in expected_value[4]:
        if custom_hashtable:
            val = cache.table.lookup("g0-{}".format(expected_k), expected_k)
        else:
            val = cache.table["g0-{}".format(expected_k)]
        assert val is not None
        assert val.value_size == expected_value_size


# Access k1, k1, k2, k3, k3, k3, k4
# When k4 is inserted,
#   LRU should evict k1.
#   LFU should evict k2.
#   MRU should evict k3.
def test_cache(cache, expected_value, custom_hashtable=True):
    k1 = TraceRecord(
        access_time=0,
        block_id=1,
        block_type=1,
        block_size=1,
        cf_id=0,
        cf_name="",
        level=0,
        fd=0,
        caller=1,
        no_insert=0,
        get_id=1,
        key_id=1,
        kv_size=5,
        is_hit=1,
        referenced_key_exist_in_block=1,
        num_keys_in_block=0,
        table_id=0,
        seq_number=0,
        block_key_size=0,
        key_size=0,
        block_offset_in_file=0,
        next_access_seq_no=0,
    )
    k2 = TraceRecord(
        access_time=1,
        block_id=2,
        block_type=1,
        block_size=1,
        cf_id=0,
        cf_name="",
        level=0,
        fd=0,
        caller=1,
        no_insert=0,
        get_id=1,
        key_id=1,
        kv_size=5,
        is_hit=1,
        referenced_key_exist_in_block=1,
        num_keys_in_block=0,
        table_id=0,
        seq_number=0,
        block_key_size=0,
        key_size=0,
        block_offset_in_file=0,
        next_access_seq_no=0,
    )
    k3 = TraceRecord(
        access_time=2,
        block_id=3,
        block_type=1,
        block_size=1,
        cf_id=0,
        cf_name="",
        level=0,
        fd=0,
        caller=1,
        no_insert=0,
        get_id=1,
        key_id=1,
        kv_size=5,
        is_hit=1,
        referenced_key_exist_in_block=1,
        num_keys_in_block=0,
        table_id=0,
        seq_number=0,
        block_key_size=0,
        key_size=0,
        block_offset_in_file=0,
        next_access_seq_no=0,
    )
    k4 = TraceRecord(
        access_time=3,
        block_id=4,
        block_type=1,
        block_size=1,
        cf_id=0,
        cf_name="",
        level=0,
        fd=0,
        caller=1,
        no_insert=0,
        get_id=1,
        key_id=1,
        kv_size=5,
        is_hit=1,
        referenced_key_exist_in_block=1,
        num_keys_in_block=0,
        table_id=0,
        seq_number=0,
        block_key_size=0,
        key_size=0,
        block_offset_in_file=0,
        next_access_seq_no=0,
    )
    sequence = [k1, k1, k2, k3, k3, k3]
    index = 0
    expected_values = []
    # Access k1, miss.
    expected_values.append([1, 1, 1, [1], []])
    # Access k1, hit.
    expected_values.append([1, 2, 1, [1], []])
    # Access k2, miss.
    expected_values.append([2, 3, 2, [1, 2], []])
    # Access k3, miss.
    expected_values.append([3, 4, 3, [1, 2, 3], []])
    # Access k3, hit.
    expected_values.append([3, 5, 3, [1, 2, 3], []])
    # Access k3, hit.
    expected_values.append([3, 6, 3, [1, 2, 3], []])
    access_time = 0
    for access in sequence:
        access.access_time = access_time
        cache.access(access)
        assert_metrics(
            cache,
            expected_values[index],
            expected_value_size=1,
            custom_hashtable=custom_hashtable,
        )
        access_time += 1
        index += 1
    k4.access_time = access_time
    cache.access(k4)
    assert_metrics(
        cache, expected_value, expected_value_size=1, custom_hashtable=custom_hashtable
    )


def test_lru_cache(cache, custom_hashtable):
    print("Test LRU cache")
    # Access k4, miss. evict k1
    test_cache(cache, [3, 7, 4, [2, 3, 4], []], custom_hashtable)
    print("Test LRU cache: Success")


def test_mru_cache():
    print("Test MRU cache")
    policies = []
    policies.append(MRUPolicy())
    # Access k4, miss. evict k3
    test_cache(
        ThompsonSamplingCache(3, False, policies, cost_class_label=None),
        [3, 7, 4, [1, 2, 4], []],
    )
    print("Test MRU cache: Success")


def test_lfu_cache():
    print("Test LFU cache")
    policies = []
    policies.append(LFUPolicy())
    # Access k4, miss. evict k2
    test_cache(
        ThompsonSamplingCache(3, False, policies, cost_class_label=None),
        [3, 7, 4, [1, 3, 4], []],
    )
    print("Test LFU cache: Success")


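# Feed the cache a randomized trace of accesses over ~100 blocks and check that
# the reported miss ratio and used size stay consistent with the cache's own
# table contents (or with the recorded hits/misses for the Trace cache).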
def test_mix(cache):
    print("Test Mix {} cache".format(cache.cache_name()))
    n = 100000
    records = 100
    block_size_table = {}
    trace_num_misses = 0
    for i in range(n):
        key_id = random.randint(0, records)
        vs = random.randint(0, 10)
        now = i * kMicrosInSecond
        block_size = vs
        if key_id in block_size_table:
            block_size = block_size_table[key_id]
        else:
            block_size_table[key_id] = block_size
        is_hit = key_id % 2
        if is_hit == 0:
            trace_num_misses += 1
        k = TraceRecord(
            access_time=now,
            block_id=key_id,
            block_type=1,
            block_size=block_size,
            cf_id=0,
            cf_name="",
            level=0,
            fd=0,
            caller=1,
            no_insert=0,
            get_id=key_id,
            key_id=key_id,
            kv_size=5,
            is_hit=is_hit,
            referenced_key_exist_in_block=1,
            num_keys_in_block=0,
            table_id=0,
            seq_number=0,
            block_key_size=0,
            key_size=0,
            block_offset_in_file=0,
            next_access_seq_no=vs,
        )
        cache.access(k)
    assert cache.miss_ratio_stats.miss_ratio() > 0
    if cache.cache_name() == "Trace":
        assert cache.miss_ratio_stats.num_accesses == n
        assert cache.miss_ratio_stats.num_misses == trace_num_misses
    else:
        assert cache.used_size <= cache.cache_size
        all_values = cache.table.values()
        cached_size = 0
        for value in all_values:
            cached_size += value.value_size
        assert cached_size == cache.used_size, "Expected {} Actual {}".format(
            cache.used_size, cached_size
        )
    print("Test Mix {} cache: Success".format(cache.cache_name()))


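# Write a synthetic trace file, replay it through a set of simulated caches,
# and check basic invariants; since all blocks have the same size, OPT must
# achieve the lowest miss ratio.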
def test_end_to_end():
    print("Test All caches")
    n = 100000
    nblocks = 1000
    block_size = 16 * 1024
    ncfs = 7
    nlevels = 6
    nfds = 100000
    trace_file_path = "test_trace"
    # All blocks are of the same size so that OPT must achieve the lowest miss
    # ratio.
    with open(trace_file_path, "w+") as trace_file:
        access_records = ""
        for i in range(n):
            key_id = random.randint(0, nblocks)
            cf_id = random.randint(0, ncfs)
            level = random.randint(0, nlevels)
            fd = random.randint(0, nfds)
            now = i * kMicrosInSecond
            access_record = ""
            access_record += "{},".format(now)
            access_record += "{},".format(key_id)
            access_record += "{},".format(9)  # block type
            access_record += "{},".format(block_size)  # block size
            access_record += "{},".format(cf_id)
            access_record += "cf_{},".format(cf_id)
            access_record += "{},".format(level)
            access_record += "{},".format(fd)
            access_record += "{},".format(key_id % 3)  # caller
            access_record += "{},".format(0)  # no insert
            access_record += "{},".format(i)  # get_id
            access_record += "{},".format(i)  # key_id
            access_record += "{},".format(100)  # kv_size
            access_record += "{},".format(1)  # is_hit
            access_record += "{},".format(1)  # referenced_key_exist_in_block
            access_record += "{},".format(10)  # num_keys_in_block
            access_record += "{},".format(1)  # table_id
            access_record += "{},".format(0)  # seq_number
            access_record += "{},".format(10)  # block key size
            access_record += "{},".format(20)  # key size
            access_record += "{},".format(0)  # block offset
            access_record = access_record[:-1]
            access_records += access_record + "\n"
        trace_file.write(access_records)

    print("Test All caches: Start testing caches")
    cache_size = block_size * nblocks / 10
    downsample_size = 1
    cache_ms = {}
    for cache_type in [
        "ts",
        "opt",
        "lru",
        "pylru",
        "linucb",
        "gdsize",
        "pyccbt",
        "pycctbbt",
    ]:
        cache = create_cache(cache_type, cache_size, downsample_size)
        run(trace_file_path, cache_type, cache, 0, -1, "all")
        cache_ms[cache_type] = cache
        assert cache.miss_ratio_stats.num_accesses == n

    for cache_type in cache_ms:
        cache = cache_ms[cache_type]
        ms = cache.miss_ratio_stats.miss_ratio()
        assert 0.0 <= ms <= 100.0
        # OPT should perform the best.
        assert cache_ms["opt"].miss_ratio_stats.miss_ratio() <= ms
        assert cache.used_size <= cache.cache_size
        all_values = cache.table.values()
        cached_size = 0
        for value in all_values:
            cached_size += value.value_size
        assert cached_size == cache.used_size, "Expected {} Actual {}".format(
            cache.used_size, cached_size
        )
        print("Test All {}: Success".format(cache.cache_name()))

    os.remove(trace_file_path)
    print("Test All: Success")


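# Exercise hybrid caching where both blocks and row (get) keys are cached.
# The last list in each expected value below tracks the cached row keys.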
def test_hybrid(cache):
    print("Test {} cache".format(cache.cache_name()))
    k = TraceRecord(
        access_time=0,
        block_id=1,
        block_type=1,
        block_size=1,
        cf_id=0,
        cf_name="",
        level=0,
        fd=0,
        caller=1,
        no_insert=0,
        get_id=1,  # the first get request.
        key_id=1,
        kv_size=0,  # no size.
        is_hit=1,
        referenced_key_exist_in_block=1,
        num_keys_in_block=0,
        table_id=0,
        seq_number=0,
        block_key_size=0,
        key_size=0,
        block_offset_in_file=0,
        next_access_seq_no=0,
    )
    cache.access(k)  # Expect a miss.
    # expected_value: used size, num accesses, num misses, cached blocks, get keys.
    assert_metrics(cache, [1, 1, 1, [1], []])
    k.access_time += 1
    k.kv_size = 1
    k.block_id = 2
    cache.access(k)  # k should be inserted.
    assert_metrics(cache, [3, 2, 2, [1, 2], [1]])
    k.access_time += 1
    k.block_id = 3
    cache.access(k)  # k should not be inserted again.
    assert_metrics(cache, [4, 3, 3, [1, 2, 3], [1]])
    # A second get request referencing the same key.
    k.access_time += 1
    k.get_id = 2
    k.block_id = 4
    k.kv_size = 0
    cache.access(k)  # k should observe a hit. No block access.
    assert_metrics(cache, [4, 4, 3, [1, 2, 3], [1]])

    # A third get request searches three files with three different keys.
    # The second key observes a hit.
    k.access_time += 1
    k.kv_size = 1
    k.get_id = 3
    k.block_id = 3
    k.key_id = 2
    cache.access(k)  # k should observe a miss. block 3 observes a hit.
    assert_metrics(cache, [5, 5, 3, [1, 2, 3], [1, 2]])

    k.access_time += 1
    k.kv_size = 1
    k.get_id = 3
    k.block_id = 4
    k.kv_size = 1
    k.key_id = 1
    cache.access(k)  # k1 should observe a hit.
    assert_metrics(cache, [5, 6, 3, [1, 2, 3], [1, 2]])

    k.access_time += 1
    k.kv_size = 1
    k.get_id = 3
    k.block_id = 4
    k.kv_size = 1
    k.key_id = 3
    # k3 should observe a miss.
    # However, as the get has already completed, we should not access k3 any more.
    cache.access(k)
    assert_metrics(cache, [5, 7, 3, [1, 2, 3], [1, 2]])

    # A fourth get request searches one file and two blocks. One row key.
    k.access_time += 1
    k.get_id = 4
    k.block_id = 5
    k.key_id = 4
    k.kv_size = 1
    cache.access(k)
    assert_metrics(cache, [7, 8, 4, [1, 2, 3, 5], [1, 2, 4]])

    # A bunch of insertions which evict cached row keys.
    for i in range(6, 100):
        k.access_time += 1
        k.get_id = 0
        k.block_id = i
        cache.access(k)

    k.get_id = 4
    k.block_id = 100  # A different block.
    k.key_id = 4  # Same row key and should not be inserted again.
    k.kv_size = 1
    cache.access(k)
    assert_metrics(
        cache, [kSampleSize, 103, 99, [i for i in range(101 - kSampleSize, 101)], []]
    )
    print("Test {} cache: Success".format(cache.cache_name()))


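# OPT (Belady's algorithm) evicts the block whose next access is furthest in
# the future; next_access_seq_no on each record drives the evictions asserted
# below.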
def test_opt_cache():
    print("Test OPT cache")
    cache = OPTCache(3)
    # seq:         0,  1,  2,  3,  4,  5,  6,  7,  8
    # key:         k1, k2, k3, k4, k5, k6, k7, k1, k8
    # next_access: 7,  19, 18, M,  M,  17, 16, 25, M
    k = TraceRecord(
        access_time=0,
        block_id=1,
        block_type=1,
        block_size=1,
        cf_id=0,
        cf_name="",
        level=0,
        fd=0,
        caller=1,
        no_insert=0,
        get_id=1,  # the first get request.
        key_id=1,
        kv_size=0,  # no size.
        is_hit=1,
        referenced_key_exist_in_block=1,
        num_keys_in_block=0,
        table_id=0,
        seq_number=0,
        block_key_size=0,
        key_size=0,
        block_offset_in_file=0,
        next_access_seq_no=7,
    )
    cache.access(k)
    assert_metrics(
        cache, [1, 1, 1, [1], []], expected_value_size=1, custom_hashtable=False
    )
    k.access_time += 1
    k.block_id = 2
    k.next_access_seq_no = 19
    cache.access(k)
    assert_metrics(
        cache, [2, 2, 2, [1, 2], []], expected_value_size=1, custom_hashtable=False
    )
    k.access_time += 1
    k.block_id = 3
    k.next_access_seq_no = 18
    cache.access(k)
    assert_metrics(
        cache, [3, 3, 3, [1, 2, 3], []], expected_value_size=1, custom_hashtable=False
    )
    k.access_time += 1
    k.block_id = 4
    k.next_access_seq_no = sys.maxsize  # Never accessed again.
    cache.access(k)
    # Evict 2 since its next access 19 is the furthest in the future.
    assert_metrics(
        cache, [3, 4, 4, [1, 3, 4], []], expected_value_size=1, custom_hashtable=False
    )
    k.access_time += 1
    k.block_id = 5
    k.next_access_seq_no = sys.maxsize  # Never accessed again.
    cache.access(k)
    # Evict 4 since its next access MAXINT is the furthest in the future.
    assert_metrics(
        cache, [3, 5, 5, [1, 3, 5], []], expected_value_size=1, custom_hashtable=False
    )
    k.access_time += 1
    k.block_id = 6
    k.next_access_seq_no = 17
    cache.access(k)
    # Evict 5 since its next access MAXINT is the furthest in the future.
    assert_metrics(
        cache, [3, 6, 6, [1, 3, 6], []], expected_value_size=1, custom_hashtable=False
    )
    k.access_time += 1
    k.block_id = 7
    k.next_access_seq_no = 16
    cache.access(k)
    # Evict 3 since its next access 18 is the furthest in the future.
    assert_metrics(
        cache, [3, 7, 7, [1, 6, 7], []], expected_value_size=1, custom_hashtable=False
    )
    k.access_time += 1
    k.block_id = 1
    k.next_access_seq_no = 25
    cache.access(k)
    assert_metrics(
        cache, [3, 8, 7, [1, 6, 7], []], expected_value_size=1, custom_hashtable=False
    )
    k.access_time += 1
    k.block_id = 8
    k.next_access_seq_no = sys.maxsize
    cache.access(k)
    # Evict 1 since its next access 25 is the furthest in the future.
    assert_metrics(
        cache, [3, 9, 8, [6, 7, 8], []], expected_value_size=1, custom_hashtable=False
    )

    # Insert a large kv pair to evict all keys.
    k.access_time += 1
    k.block_id = 10
    k.block_size = 3
    k.next_access_seq_no = sys.maxsize
    cache.access(k)
    assert_metrics(
        cache, [3, 10, 9, [10], []], expected_value_size=3, custom_hashtable=False
    )
    print("Test OPT cache: Success")


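# TraceCache simply replays the hit/miss outcome recorded in each access
# (is_hit) instead of simulating an eviction policy.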
def test_trace_cache():
    print("Test trace cache")
    cache = TraceCache(0)
    k = TraceRecord(
        access_time=0,
        block_id=1,
        block_type=1,
        block_size=1,
        cf_id=0,
        cf_name="",
        level=0,
        fd=0,
        caller=1,
        no_insert=0,
        get_id=1,
        key_id=1,
        kv_size=0,
        is_hit=1,
        referenced_key_exist_in_block=1,
        num_keys_in_block=0,
        table_id=0,
        seq_number=0,
        block_key_size=0,
        key_size=0,
        block_offset_in_file=0,
        next_access_seq_no=7,
    )
    cache.access(k)
    assert cache.miss_ratio_stats.num_accesses == 1
    assert cache.miss_ratio_stats.num_misses == 0
    k.is_hit = 0
    cache.access(k)
    assert cache.miss_ratio_stats.num_accesses == 2
    assert cache.miss_ratio_stats.num_misses == 1
    print("Test trace cache: Success")


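# Run the targeted unit tests first, then the randomized mix test for every
# cache type and row-cache mode, and finally the end-to-end trace replay.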
if __name__ == "__main__":
    test_hash_table()
    test_trace_cache()
    test_opt_cache()
    test_lru_cache(
        ThompsonSamplingCache(
            3, enable_cache_row_key=0, policies=[LRUPolicy()], cost_class_label=None
        ),
        custom_hashtable=True,
    )
    test_lru_cache(LRUCache(3, enable_cache_row_key=0), custom_hashtable=False)
    test_mru_cache()
    test_lfu_cache()
    test_hybrid(
        ThompsonSamplingCache(
            kSampleSize,
            enable_cache_row_key=1,
            policies=[LRUPolicy()],
            cost_class_label=None,
        )
    )
    test_hybrid(
        LinUCBCache(
            kSampleSize,
            enable_cache_row_key=1,
            policies=[LRUPolicy()],
            cost_class_label=None,
        )
    )
    for cache_type in [
        "ts",
        "opt",
        "arc",
        "pylfu",
        "pymru",
        "trace",
        "pyhb",
        "lru",
        "pylru",
        "linucb",
        "gdsize",
        "pycctbbt",
        "pycctb",
        "pyccbt",
    ]:
        for enable_row_cache in [0, 1, 2]:
            cache_type_str = cache_type
            if cache_type != "opt" and cache_type != "trace":
                if enable_row_cache == 1:
                    cache_type_str += "_hybrid"
                elif enable_row_cache == 2:
                    cache_type_str += "_hybridn"
            test_mix(create_cache(cache_type_str, cache_size=100, downsample_size=1))
    test_end_to_end()