1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef COUNT_NEW_H
10 #define COUNT_NEW_H
11
12 # include <cstdlib>
13 # include <cassert>
14 # include <new>
15
16 #include "test_macros.h"
17
18 #if defined(TEST_HAS_SANITIZERS)
19 #define DISABLE_NEW_COUNT
20 #endif
21
22 namespace detail
23 {
24 TEST_NORETURN
throw_bad_alloc_helper()25 inline void throw_bad_alloc_helper() {
26 #ifndef TEST_HAS_NO_EXCEPTIONS
27 throw std::bad_alloc();
28 #else
29 std::abort();
30 #endif
31 }
32 }
33
34 class MemCounter
35 {
36 public:
37 // Make MemCounter super hard to accidentally construct or copy.
38 class MemCounterCtorArg_ {};
MemCounter(MemCounterCtorArg_)39 explicit MemCounter(MemCounterCtorArg_) { reset(); }
40
41 private:
42 MemCounter(MemCounter const &);
43 MemCounter & operator=(MemCounter const &);
44
45 public:
46 // All checks return true when disable_checking is enabled.
47 static const bool disable_checking;
48
49 // Disallow any allocations from occurring. Useful for testing that
50 // code doesn't perform any allocations.
51 bool disable_allocations;
52
53 // number of allocations to throw after. Default (unsigned)-1. If
54 // throw_after has the default value it will never be decremented.
55 static const unsigned never_throw_value = static_cast<unsigned>(-1);
56 unsigned throw_after;
57
58 int outstanding_new;
59 int new_called;
60 int delete_called;
61 int aligned_new_called;
62 int aligned_delete_called;
63 std::size_t last_new_size;
64 std::size_t last_new_align;
65 std::size_t last_delete_align;
66
67 int outstanding_array_new;
68 int new_array_called;
69 int delete_array_called;
70 int aligned_new_array_called;
71 int aligned_delete_array_called;
72 std::size_t last_new_array_size;
73 std::size_t last_new_array_align;
74 std::size_t last_delete_array_align;
75
76 public:
newCalled(std::size_t s)77 void newCalled(std::size_t s)
78 {
79 assert(disable_allocations == false);
80 assert(s);
81 if (throw_after == 0) {
82 throw_after = never_throw_value;
83 detail::throw_bad_alloc_helper();
84 } else if (throw_after != never_throw_value) {
85 --throw_after;
86 }
87 ++new_called;
88 ++outstanding_new;
89 last_new_size = s;
90 }
91
alignedNewCalled(std::size_t s,std::size_t a)92 void alignedNewCalled(std::size_t s, std::size_t a) {
93 newCalled(s);
94 ++aligned_new_called;
95 last_new_align = a;
96 }
97
deleteCalled(void * p)98 void deleteCalled(void * p)
99 {
100 assert(p);
101 --outstanding_new;
102 ++delete_called;
103 }
104
alignedDeleteCalled(void * p,std::size_t a)105 void alignedDeleteCalled(void *p, std::size_t a) {
106 deleteCalled(p);
107 ++aligned_delete_called;
108 last_delete_align = a;
109 }
110
newArrayCalled(std::size_t s)111 void newArrayCalled(std::size_t s)
112 {
113 assert(disable_allocations == false);
114 assert(s);
115 if (throw_after == 0) {
116 throw_after = never_throw_value;
117 detail::throw_bad_alloc_helper();
118 } else {
119 // don't decrement throw_after here. newCalled will end up doing that.
120 }
121 ++outstanding_array_new;
122 ++new_array_called;
123 last_new_array_size = s;
124 }
125
alignedNewArrayCalled(std::size_t s,std::size_t a)126 void alignedNewArrayCalled(std::size_t s, std::size_t a) {
127 newArrayCalled(s);
128 ++aligned_new_array_called;
129 last_new_array_align = a;
130 }
131
deleteArrayCalled(void * p)132 void deleteArrayCalled(void * p)
133 {
134 assert(p);
135 --outstanding_array_new;
136 ++delete_array_called;
137 }
138
alignedDeleteArrayCalled(void * p,std::size_t a)139 void alignedDeleteArrayCalled(void * p, std::size_t a) {
140 deleteArrayCalled(p);
141 ++aligned_delete_array_called;
142 last_delete_array_align = a;
143 }
144
disableAllocations()145 void disableAllocations()
146 {
147 disable_allocations = true;
148 }
149
enableAllocations()150 void enableAllocations()
151 {
152 disable_allocations = false;
153 }
154
reset()155 void reset()
156 {
157 disable_allocations = false;
158 throw_after = never_throw_value;
159
160 outstanding_new = 0;
161 new_called = 0;
162 delete_called = 0;
163 aligned_new_called = 0;
164 aligned_delete_called = 0;
165 last_new_size = 0;
166 last_new_align = 0;
167
168 outstanding_array_new = 0;
169 new_array_called = 0;
170 delete_array_called = 0;
171 aligned_new_array_called = 0;
172 aligned_delete_array_called = 0;
173 last_new_array_size = 0;
174 last_new_array_align = 0;
175 }
176
177 public:
checkOutstandingNewEq(int n)178 bool checkOutstandingNewEq(int n) const
179 {
180 return disable_checking || n == outstanding_new;
181 }
182
checkOutstandingNewNotEq(int n)183 bool checkOutstandingNewNotEq(int n) const
184 {
185 return disable_checking || n != outstanding_new;
186 }
187
checkNewCalledEq(int n)188 bool checkNewCalledEq(int n) const
189 {
190 return disable_checking || n == new_called;
191 }
192
checkNewCalledNotEq(int n)193 bool checkNewCalledNotEq(int n) const
194 {
195 return disable_checking || n != new_called;
196 }
197
checkNewCalledGreaterThan(int n)198 bool checkNewCalledGreaterThan(int n) const
199 {
200 return disable_checking || new_called > n;
201 }
202
checkDeleteCalledEq(int n)203 bool checkDeleteCalledEq(int n) const
204 {
205 return disable_checking || n == delete_called;
206 }
207
checkDeleteCalledNotEq(int n)208 bool checkDeleteCalledNotEq(int n) const
209 {
210 return disable_checking || n != delete_called;
211 }
212
checkAlignedNewCalledEq(int n)213 bool checkAlignedNewCalledEq(int n) const
214 {
215 return disable_checking || n == aligned_new_called;
216 }
217
checkAlignedNewCalledNotEq(int n)218 bool checkAlignedNewCalledNotEq(int n) const
219 {
220 return disable_checking || n != aligned_new_called;
221 }
222
checkAlignedNewCalledGreaterThan(int n)223 bool checkAlignedNewCalledGreaterThan(int n) const
224 {
225 return disable_checking || aligned_new_called > n;
226 }
227
checkAlignedDeleteCalledEq(int n)228 bool checkAlignedDeleteCalledEq(int n) const
229 {
230 return disable_checking || n == aligned_delete_called;
231 }
232
checkAlignedDeleteCalledNotEq(int n)233 bool checkAlignedDeleteCalledNotEq(int n) const
234 {
235 return disable_checking || n != aligned_delete_called;
236 }
237
checkLastNewSizeEq(std::size_t n)238 bool checkLastNewSizeEq(std::size_t n) const
239 {
240 return disable_checking || n == last_new_size;
241 }
242
checkLastNewSizeNotEq(std::size_t n)243 bool checkLastNewSizeNotEq(std::size_t n) const
244 {
245 return disable_checking || n != last_new_size;
246 }
247
checkLastNewAlignEq(std::size_t n)248 bool checkLastNewAlignEq(std::size_t n) const
249 {
250 return disable_checking || n == last_new_align;
251 }
252
checkLastNewAlignNotEq(std::size_t n)253 bool checkLastNewAlignNotEq(std::size_t n) const
254 {
255 return disable_checking || n != last_new_align;
256 }
257
checkLastDeleteAlignEq(std::size_t n)258 bool checkLastDeleteAlignEq(std::size_t n) const
259 {
260 return disable_checking || n == last_delete_align;
261 }
262
checkLastDeleteAlignNotEq(std::size_t n)263 bool checkLastDeleteAlignNotEq(std::size_t n) const
264 {
265 return disable_checking || n != last_delete_align;
266 }
267
checkOutstandingArrayNewEq(int n)268 bool checkOutstandingArrayNewEq(int n) const
269 {
270 return disable_checking || n == outstanding_array_new;
271 }
272
checkOutstandingArrayNewNotEq(int n)273 bool checkOutstandingArrayNewNotEq(int n) const
274 {
275 return disable_checking || n != outstanding_array_new;
276 }
277
checkNewArrayCalledEq(int n)278 bool checkNewArrayCalledEq(int n) const
279 {
280 return disable_checking || n == new_array_called;
281 }
282
checkNewArrayCalledNotEq(int n)283 bool checkNewArrayCalledNotEq(int n) const
284 {
285 return disable_checking || n != new_array_called;
286 }
287
checkDeleteArrayCalledEq(int n)288 bool checkDeleteArrayCalledEq(int n) const
289 {
290 return disable_checking || n == delete_array_called;
291 }
292
checkDeleteArrayCalledNotEq(int n)293 bool checkDeleteArrayCalledNotEq(int n) const
294 {
295 return disable_checking || n != delete_array_called;
296 }
297
checkAlignedNewArrayCalledEq(int n)298 bool checkAlignedNewArrayCalledEq(int n) const
299 {
300 return disable_checking || n == aligned_new_array_called;
301 }
302
checkAlignedNewArrayCalledNotEq(int n)303 bool checkAlignedNewArrayCalledNotEq(int n) const
304 {
305 return disable_checking || n != aligned_new_array_called;
306 }
307
checkAlignedNewArrayCalledGreaterThan(int n)308 bool checkAlignedNewArrayCalledGreaterThan(int n) const
309 {
310 return disable_checking || aligned_new_array_called > n;
311 }
312
checkAlignedDeleteArrayCalledEq(int n)313 bool checkAlignedDeleteArrayCalledEq(int n) const
314 {
315 return disable_checking || n == aligned_delete_array_called;
316 }
317
checkAlignedDeleteArrayCalledNotEq(int n)318 bool checkAlignedDeleteArrayCalledNotEq(int n) const
319 {
320 return disable_checking || n != aligned_delete_array_called;
321 }
322
checkLastNewArraySizeEq(std::size_t n)323 bool checkLastNewArraySizeEq(std::size_t n) const
324 {
325 return disable_checking || n == last_new_array_size;
326 }
327
checkLastNewArraySizeNotEq(std::size_t n)328 bool checkLastNewArraySizeNotEq(std::size_t n) const
329 {
330 return disable_checking || n != last_new_array_size;
331 }
332
checkLastNewArrayAlignEq(std::size_t n)333 bool checkLastNewArrayAlignEq(std::size_t n) const
334 {
335 return disable_checking || n == last_new_array_align;
336 }
337
checkLastNewArrayAlignNotEq(std::size_t n)338 bool checkLastNewArrayAlignNotEq(std::size_t n) const
339 {
340 return disable_checking || n != last_new_array_align;
341 }
342 };
343
344 #ifdef DISABLE_NEW_COUNT
345 const bool MemCounter::disable_checking = true;
346 #else
347 const bool MemCounter::disable_checking = false;
348 #endif
349
350 TEST_DIAGNOSTIC_PUSH
351 TEST_MSVC_DIAGNOSTIC_IGNORED(4640) // '%s' construction of local static object is not thread safe (/Zc:threadSafeInit-)
getGlobalMemCounter()352 inline MemCounter* getGlobalMemCounter() {
353 static MemCounter counter((MemCounter::MemCounterCtorArg_()));
354 return &counter;
355 }
356 TEST_DIAGNOSTIC_POP
357
358 MemCounter &globalMemCounter = *getGlobalMemCounter();
359
360 #ifndef DISABLE_NEW_COUNT
new(std::size_t s)361 void* operator new(std::size_t s) TEST_THROW_SPEC(std::bad_alloc)
362 {
363 getGlobalMemCounter()->newCalled(s);
364 void* ret = std::malloc(s);
365 if (ret == nullptr)
366 detail::throw_bad_alloc_helper();
367 return ret;
368 }
369
delete(void * p)370 void operator delete(void* p) TEST_NOEXCEPT
371 {
372 getGlobalMemCounter()->deleteCalled(p);
373 std::free(p);
374 }
375
TEST_THROW_SPEC(std::bad_alloc)376 void* operator new[](std::size_t s) TEST_THROW_SPEC(std::bad_alloc)
377 {
378 getGlobalMemCounter()->newArrayCalled(s);
379 return operator new(s);
380 }
381
382 void operator delete[](void* p) TEST_NOEXCEPT
383 {
384 getGlobalMemCounter()->deleteArrayCalled(p);
385 operator delete(p);
386 }
387
388 #ifndef TEST_HAS_NO_ALIGNED_ALLOCATION
// Select the aligned-allocation backend: on MSVC-like C runtimes, and on Windows
// when not using libc++, aligned storage comes from _aligned_malloc /
// _aligned_free; everywhere else posix_memalign / free are used.
#if defined(_LIBCPP_MSVCRT_LIKE)
#  define USE_ALIGNED_ALLOC
#elif defined(_WIN32) && !defined(_LIBCPP_VERSION)
#  define USE_ALIGNED_ALLOC
#endif
393
new(std::size_t s,std::align_val_t av)394 void* operator new(std::size_t s, std::align_val_t av) TEST_THROW_SPEC(std::bad_alloc) {
395 const std::size_t a = static_cast<std::size_t>(av);
396 getGlobalMemCounter()->alignedNewCalled(s, a);
397 void *ret;
398 #ifdef USE_ALIGNED_ALLOC
399 ret = _aligned_malloc(s, a);
400 #else
401 posix_memalign(&ret, a, s);
402 #endif
403 if (ret == nullptr)
404 detail::throw_bad_alloc_helper();
405 return ret;
406 }
407
delete(void * p,std::align_val_t av)408 void operator delete(void *p, std::align_val_t av) TEST_NOEXCEPT {
409 const std::size_t a = static_cast<std::size_t>(av);
410 getGlobalMemCounter()->alignedDeleteCalled(p, a);
411 if (p) {
412 #ifdef USE_ALIGNED_ALLOC
413 ::_aligned_free(p);
414 #else
415 ::free(p);
416 #endif
417 }
418 }
419
TEST_THROW_SPEC(std::bad_alloc)420 void* operator new[](std::size_t s, std::align_val_t av) TEST_THROW_SPEC(std::bad_alloc) {
421 const std::size_t a = static_cast<std::size_t>(av);
422 getGlobalMemCounter()->alignedNewArrayCalled(s, a);
423 return operator new(s, av);
424 }
425
426 void operator delete[](void *p, std::align_val_t av) TEST_NOEXCEPT {
427 const std::size_t a = static_cast<std::size_t>(av);
428 getGlobalMemCounter()->alignedDeleteArrayCalled(p, a);
429 return operator delete(p, av);
430 }
431
432 #endif // TEST_HAS_NO_ALIGNED_ALLOCATION
433
434 #endif // DISABLE_NEW_COUNT
435
436 struct DisableAllocationGuard {
m_disabledDisableAllocationGuard437 explicit DisableAllocationGuard(bool disable = true) : m_disabled(disable)
438 {
439 // Don't re-disable if already disabled.
440 if (globalMemCounter.disable_allocations == true) m_disabled = false;
441 if (m_disabled) globalMemCounter.disableAllocations();
442 }
443
releaseDisableAllocationGuard444 void release() {
445 if (m_disabled) globalMemCounter.enableAllocations();
446 m_disabled = false;
447 }
448
~DisableAllocationGuardDisableAllocationGuard449 ~DisableAllocationGuard() {
450 release();
451 }
452
453 private:
454 bool m_disabled;
455
456 DisableAllocationGuard(DisableAllocationGuard const&);
457 DisableAllocationGuard& operator=(DisableAllocationGuard const&);
458 };
459
460 struct RequireAllocationGuard {
461 explicit RequireAllocationGuard(std::size_t RequireAtLeast = 1)
m_req_allocRequireAllocationGuard462 : m_req_alloc(RequireAtLeast),
463 m_new_count_on_init(globalMemCounter.new_called),
464 m_outstanding_new_on_init(globalMemCounter.outstanding_new),
465 m_exactly(false)
466 {
467 }
468
requireAtLeastRequireAllocationGuard469 void requireAtLeast(std::size_t N) { m_req_alloc = N; m_exactly = false; }
requireExactlyRequireAllocationGuard470 void requireExactly(std::size_t N) { m_req_alloc = N; m_exactly = true; }
471
~RequireAllocationGuardRequireAllocationGuard472 ~RequireAllocationGuard() {
473 assert(globalMemCounter.checkOutstandingNewEq(static_cast<int>(m_outstanding_new_on_init)));
474 std::size_t Expect = m_new_count_on_init + m_req_alloc;
475 assert(globalMemCounter.checkNewCalledEq(static_cast<int>(Expect)) ||
476 (!m_exactly && globalMemCounter.checkNewCalledGreaterThan(static_cast<int>(Expect))));
477 }
478
479 private:
480 std::size_t m_req_alloc;
481 const std::size_t m_new_count_on_init;
482 const std::size_t m_outstanding_new_on_init;
483 bool m_exactly;
484 RequireAllocationGuard(RequireAllocationGuard const&);
485 RequireAllocationGuard& operator=(RequireAllocationGuard const&);
486 };
487
488 #endif /* COUNT_NEW_H */
489