1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // FIXME: In MSVC mode, even "std::function<int(int)> f(aref);" causes
10 // allocations.
11 // XFAIL: target=x86_64-pc-windows-msvc && stdlib=libc++
12 
13 // <functional>
14 
15 // class function<R(ArgTypes...)>
16 
17 // function(const function&  f);
18 // function(function&& f); // noexcept in C++20
19 
20 // This test runs in C++03, but we have deprecated using std::function in C++03.
21 // ADDITIONAL_COMPILE_FLAGS: -D_LIBCPP_DISABLE_DEPRECATION_WARNINGS -D_LIBCPP_ENABLE_CXX03_FUNCTION
22 
#include <functional>
#include <memory>
#include <utility>
#include <cstdlib>
#include <cassert>
27 
28 #include "test_macros.h"
29 #include "count_new.h"
30 
// Callable test type carrying a ten-int payload. Presumably sized so that
// std::function cannot keep it in its small buffer: main() expects exactly one
// outstanding allocation after wrapping an A. `count` tracks the number of
// live A objects so the test can detect leaked or extra copies.
class A
{
    int data_[10];
public:
    static int count;

    A()
    {
        ++count;
        for (int i = 0; i < 10; ++i)
            data_[i] = i;
    }

    // Copies the payload as well as bumping the live count, so that invoking
    // a copy yields the same result as invoking the original (otherwise the
    // copy's data_ would be indeterminate and operator() would read it).
    A(const A& other)
    {
        ++count;
        for (int i = 0; i < 10; ++i)
            data_[i] = other.data_[i];
    }

    ~A() {--count;}

    // Returns i plus the sum of the payload; for an A built by the default
    // constructor (or copied from one) that sum is 0 + 1 + ... + 9 == 45.
    int operator()(int i) const
    {
        for (int j = 0; j < 10; ++j)
            i += data_[j];
        return i;
    }
};

int A::count = 0;
57 
// Plain free function with the int(int) signature used as a function-pointer
// target; it ignores its argument and always yields 0.
int g(int /* ignored */)
{
    return 0;
}
59 
60 int main(int, char**)
61 {
62     globalMemCounter.reset();
63     assert(globalMemCounter.checkOutstandingNewEq(0));
64     {
65     std::function<int(int)> f = A();
66     assert(A::count == 1);
67     assert(globalMemCounter.checkOutstandingNewEq(1));
68     RTTI_ASSERT(f.target<A>());
69     RTTI_ASSERT(f.target<int(*)(int)>() == 0);
70     std::function<int(int)> f2 = f;
71     assert(A::count == 2);
72     assert(globalMemCounter.checkOutstandingNewEq(2));
73     RTTI_ASSERT(f2.target<A>());
74     RTTI_ASSERT(f2.target<int(*)(int)>() == 0);
75     }
76     assert(A::count == 0);
77     assert(globalMemCounter.checkOutstandingNewEq(0));
78     {
79     std::function<int(int)> f = g;
80     assert(globalMemCounter.checkOutstandingNewEq(0));
81     RTTI_ASSERT(f.target<int(*)(int)>());
82     RTTI_ASSERT(f.target<A>() == 0);
83     std::function<int(int)> f2 = f;
84     assert(globalMemCounter.checkOutstandingNewEq(0));
85     RTTI_ASSERT(f2.target<int(*)(int)>());
86     RTTI_ASSERT(f2.target<A>() == 0);
87     }
88     assert(globalMemCounter.checkOutstandingNewEq(0));
89     {
90     std::function<int(int)> f;
91     assert(globalMemCounter.checkOutstandingNewEq(0));
92     RTTI_ASSERT(f.target<int(*)(int)>() == 0);
93     RTTI_ASSERT(f.target<A>() == 0);
94     std::function<int(int)> f2 = f;
95     assert(globalMemCounter.checkOutstandingNewEq(0));
96     RTTI_ASSERT(f2.target<int(*)(int)>() == 0);
97     RTTI_ASSERT(f2.target<A>() == 0);
98     }
99     {
100     std::function<int(int)> f;
101     assert(globalMemCounter.checkOutstandingNewEq(0));
102     RTTI_ASSERT(f.target<int(*)(int)>() == 0);
103     RTTI_ASSERT(f.target<A>() == 0);
104     assert(!f);
105     std::function<long(int)> g = f;
106     assert(globalMemCounter.checkOutstandingNewEq(0));
107     RTTI_ASSERT(g.target<long(*)(int)>() == 0);
108     RTTI_ASSERT(g.target<A>() == 0);
109     assert(!g);
110     }
111 #if TEST_STD_VER >= 11
112     assert(globalMemCounter.checkOutstandingNewEq(0));
113     { // Test rvalue references
114         std::function<int(int)> f = A();
115         assert(A::count == 1);
116         assert(globalMemCounter.checkOutstandingNewEq(1));
117         RTTI_ASSERT(f.target<A>());
118         RTTI_ASSERT(f.target<int(*)(int)>() == 0);
119 		LIBCPP_ASSERT_NOEXCEPT(std::function<int(int)>(std::move(f)));
120 #if TEST_STD_VER > 17
121 		ASSERT_NOEXCEPT(std::function<int(int)>(std::move(f)));
122 #endif
123         std::function<int(int)> f2 = std::move(f);
124         assert(A::count == 1);
125         assert(globalMemCounter.checkOutstandingNewEq(1));
126         RTTI_ASSERT(f2.target<A>());
127         RTTI_ASSERT(f2.target<int(*)(int)>() == 0);
128         RTTI_ASSERT(f.target<A>() == 0);
129         RTTI_ASSERT(f.target<int(*)(int)>() == 0);
130     }
131     assert(globalMemCounter.checkOutstandingNewEq(0));
132     {
133         // Test that moving a function constructed from a reference wrapper
134         // is done without allocating.
135         DisableAllocationGuard g;
136         using Ref = std::reference_wrapper<A>;
137         A a;
138         Ref aref(a);
139         std::function<int(int)> f(aref);
140         assert(A::count == 1);
141         RTTI_ASSERT(f.target<A>() == nullptr);
142         RTTI_ASSERT(f.target<Ref>());
143 		LIBCPP_ASSERT_NOEXCEPT(std::function<int(int)>(std::move(f)));
144 #if TEST_STD_VER > 17
145 		ASSERT_NOEXCEPT(std::function<int(int)>(std::move(f)));
146 #endif
147         std::function<int(int)> f2(std::move(f));
148         assert(A::count == 1);
149         RTTI_ASSERT(f2.target<A>() == nullptr);
150         RTTI_ASSERT(f2.target<Ref>());
151 #if defined(_LIBCPP_VERSION)
152         RTTI_ASSERT(f.target<Ref>()); // f is unchanged because the target is small
153 #endif
154     }
155     {
156         // Test that moving a function constructed from a function pointer
157         // is done without allocating
158         DisableAllocationGuard guard;
159         using Ptr = int(*)(int);
160         Ptr p = g;
161         std::function<int(int)> f(p);
162         RTTI_ASSERT(f.target<A>() == nullptr);
163         RTTI_ASSERT(f.target<Ptr>());
164 		LIBCPP_ASSERT_NOEXCEPT(std::function<int(int)>(std::move(f)));
165 #if TEST_STD_VER > 17
166 		ASSERT_NOEXCEPT(std::function<int(int)>(std::move(f)));
167 #endif
168         std::function<int(int)> f2(std::move(f));
169         RTTI_ASSERT(f2.target<A>() == nullptr);
170         RTTI_ASSERT(f2.target<Ptr>());
171 #if defined(_LIBCPP_VERSION)
172         RTTI_ASSERT(f.target<Ptr>()); // f is unchanged because the target is small
173 #endif
174     }
175 #endif // TEST_STD_VER >= 11
176 
177   return 0;
178 }
179