1 /* -*- Mode: C; tab-width: 4; c-basic-offset: 4; indent-tabs-mode: nil -*- */
2 
3 #include "proxy.h"
4 #include "proxy_tls.h"
5 #ifdef PROXY_TLS
6 #include <openssl/ssl.h>
7 #include <openssl/err.h>
8 
9 /* Notes on ERR_clear_error() and friends:
10  * - Errors from SSL calls leave errors on a thread-local "error stack"
11  * - If an error is received from an SSL call, the stack needs to be inspected
12  *   and cleared.
13  * - The error stack _must_ be clear before any SSL_get_error() calls, as it
14  *   may return garbage.
15  * - There may be _multiple_ errors queued after one SSL call, so just
16  *   checking the top level does not clear it.
17  * - ERR_clear_error() is not "free", so we would prefer to avoid calling it
18  *   before hotpath calls. Thus, we should ensure it's called _after_ any
19  *   hotpath call that receives any kind of error.
20  * - We should also call it _before_ any non-hotpath SSL calls (such as
21  *   SSL_connect()) for defense against bugs in our code or OpenSSL.
22  */
23 
// Create the OpenSSL client context stored in ctx->tls_ctx. Per-backend SSL
// objects are later created from it (see SSL_new() in mcp_tls_backend_init).
//
// Idempotent: if ctx->tls_ctx already exists, nothing is done.
//
// Returns MCP_TLS_OK on success (both when newly created and when already
// initialized), MCP_TLS_ERR on failure. On failure ctx->tls_ctx is left
// untouched and the OpenSSL error queue is cleared.
int mcp_tls_init(proxy_ctx_t *ctx) {
    if (ctx->tls_ctx) {
        return MCP_TLS_OK;
    }

    // Non hot path: start from an empty error queue.
    ERR_clear_error();

    // TODO: check for OpenSSL 1.1+ ? should be elsewhere in the code.
    SSL_CTX *tctx = SSL_CTX_new(TLS_client_method());
    if (tctx == NULL) {
        ERR_clear_error();
        return MCP_TLS_ERR;
    }

    // TODO: make configurable like main cache server
    // This returns 1 on success and 0 on failure. Don't silently continue
    // with the library's default (older) minimum protocol version.
    if (SSL_CTX_set_min_proto_version(tctx, TLS1_3_VERSION) != 1) {
        SSL_CTX_free(tctx);
        ERR_clear_error();
        return MCP_TLS_ERR;
    }
    // reduce memory consumption of idle backends.
    SSL_CTX_set_mode(tctx, SSL_MODE_RELEASE_BUFFERS);

    ctx->tls_ctx = tctx;
    // Same success code as the already-initialized early return above.
    return MCP_TLS_OK;
}
43 
// Allocate the per-backend SSL object from the shared ctx->tls_ctx.
//
// Backends whose parent does not have tunables.use_tls set are skipped.
// On allocation failure be->ssl is left as it was and MCP_TLS_ERR is
// returned; otherwise MCP_TLS_OK.
int mcp_tls_backend_init(proxy_ctx_t *ctx, struct mcp_backendconn_s *be) {
    if (be->be_parent->tunables.use_tls) {
        SSL *conn = SSL_new(ctx->tls_ctx);
        if (conn == NULL) {
            return MCP_TLS_ERR;
        }
        // No SSL_set_fd() here on purpose: it would free any existing BIO
        // and allocate a new one, so the fd is attached at connect time.
        be->ssl = conn;
    }

    return MCP_TLS_OK;
}
60 
mcp_tls_shutdown(struct mcp_backendconn_s * be)61 int mcp_tls_shutdown(struct mcp_backendconn_s *be) {
62     if (!be->ssl) {
63         return MCP_TLS_OK;
64     }
65 
66     // TODO: This may need to be called multiple times to "properly" shutdown
67     // a session. However we only ever call this when a backend is dead or not
68     // in used anymore. Unclear if checking for WANT_READ|WRITE is worth
69     // doing.
70     SSL_shutdown(be->ssl);
71 
72     return MCP_TLS_OK;
73 }
74 
mcp_tls_cleanup(struct mcp_backendconn_s * be)75 int mcp_tls_cleanup(struct mcp_backendconn_s *be) {
76     if (!be->ssl) {
77         return MCP_TLS_OK;
78     }
79 
80     SSL_free(be->ssl);
81     be->ssl = NULL;
82     return MCP_TLS_OK;
83 }
84 
85 // Contrary to the name of this function, the underlying tcp socket must
86 // already be connected.
mcp_tls_connect(struct mcp_backendconn_s * be)87 int mcp_tls_connect(struct mcp_backendconn_s *be) {
88     // TODO: check return code. can fail if BIO fails to alloc.
89     SSL_set_fd(be->ssl, mcmc_fd(be->client));
90 
91     // TODO:
92     // if the backend is changing TLS version or some similar issue, we will
93     // be unable to reconnect as the SSL object "Caches" some information
94     // about the previous request (why doesn't clear work then???)
95     // This will normally be fine, but we should detect severe errors here and
96     // decide if we should free and re-alloc the SSL object.
97     // Allocating the SSL object can be pretty slow, so we should at least
98     // attempt to do not do this.
99     // Related: https://github.com/openssl/openssl/issues/20286
100     SSL_clear(be->ssl);
101     ERR_clear_error();
102     int n = SSL_connect(be->ssl);
103     int ret = MCP_TLS_OK;
104     // TODO: complete error handling.
105     if (n == 1) {
106         // Successfully established and handshake complete.
107         return ret;
108     }
109 
110     int err = SSL_get_error(be->ssl, n);
111     if (n == 0) {
112         // Not successsful, but shut down normally.
113         ERR_clear_error();
114         ret = MCP_TLS_ERR;
115     } else if (n < 0) {
116         // Not successful. Check for temporary error.
117         if (err == SSL_ERROR_WANT_READ ||
118             err == SSL_ERROR_WANT_WRITE) {
119             ret = MCP_TLS_OK;
120         } else {
121             ret = MCP_TLS_ERR;
122         }
123         ERR_clear_error();
124     }
125 
126     return ret;
127 }
128 
mcp_tls_handshake(struct mcp_backendconn_s * be)129 int mcp_tls_handshake(struct mcp_backendconn_s *be) {
130     if (SSL_is_init_finished(be->ssl)) {
131         return MCP_TLS_OK;
132     }
133 
134     // Non hot path, so clear errors before running.
135     ERR_clear_error();
136     int n = SSL_do_handshake(be->ssl);
137     if (n == 1) {
138         return MCP_TLS_OK;
139     }
140 
141     int err = SSL_get_error(be->ssl, n);
142     // TODO: realistically we'll only ever get WANT_READ here, since OpenSSL
143     // is handling the fd and it will have written a small number of bytes.
144     // leaving this note just in case.
145     if (err == SSL_ERROR_WANT_READ ||
146         err == SSL_ERROR_WANT_WRITE) {
147         // So far as I can tell there would be an error on the queue here.
148         ERR_clear_error();
149         return MCP_TLS_NEEDIO;
150     } else {
151         // TODO: can get the full error message and give to the caller to log
152         // to proxyevents?
153         ERR_clear_error();
154         return MCP_TLS_ERR;
155     }
156 }
157 
mcp_tls_send_validate(struct mcp_backendconn_s * be)158 int mcp_tls_send_validate(struct mcp_backendconn_s *be) {
159     const char *str = "version\r\n";
160     const size_t len = strlen(str);
161 
162     // Non hot path, clear errors.
163     ERR_clear_error();
164     int n = SSL_write(be->ssl, str, len);
165 
166     // TODO: more detailed error checking.
167     if (n < 0 || n != len) {
168         ERR_clear_error();
169         return MCP_TLS_ERR;
170     }
171 
172     return MCP_TLS_OK;
173 }
174 
mcp_tls_read(struct mcp_backendconn_s * be)175 int mcp_tls_read(struct mcp_backendconn_s *be) {
176     int n = SSL_read(be->ssl, be->rbuf + be->rbufused, READ_BUFFER_SIZE - be->rbufused);
177 
178     if (n < 0) {
179         int err = SSL_get_error(be->ssl, n);
180         if (err == SSL_ERROR_WANT_WRITE ||
181             err == SSL_ERROR_WANT_READ) {
182             ERR_clear_error();
183             return MCP_TLS_NEEDIO;
184         } else {
185             // TODO: log detailed error.
186             ERR_clear_error();
187             return MCP_TLS_ERR;
188         }
189     } else {
190         be->rbufused += n;
191         return n;
192     }
193 
194     return 0;
195 }
196 
197 // TODO: option.
198 #define TLS_WBUF_SIZE 16 * 1024
199 
200 // We cache the temporary write buffer on the be's event thread.
201 // This is actually required when retrying ops (WANT_WRITE/etc) unless
202 // MOVING_BUFFERS flag is set in OpenSSL.
mcp_tls_writev(struct mcp_backendconn_s * be,int iovcnt)203 int mcp_tls_writev(struct mcp_backendconn_s *be, int iovcnt) {
204     proxy_event_thread_t *et = be->event_thread;
205     // TODO: move this to event thread init to remove branch and move error
206     // handling to startup time.
207     // Actually we won't know if TLS is in use until a backend shows up and
208     // tries to write... so I'm not sure where to move this. TLS compiled in
209     // but not used would waste memory.
210     // Maybe can at least mark it unlikely()?
211     if (et->tls_wbuf_size == 0) {
212         et->tls_wbuf_size = TLS_WBUF_SIZE;
213         et->tls_wbuf = malloc(et->tls_wbuf_size);
214         if (et->tls_wbuf == NULL) {
215             return MCP_TLS_ERR;
216         }
217     }
218     size_t remain = et->tls_wbuf_size;
219     char *b = et->tls_wbuf;
220 
221     // OpenSSL has no writev or TCP_CORK equivalent, so we have to pre-mempcy
222     // the iov's here.
223     for (int i = 0; i < iovcnt; i++) {
224         size_t len = be->write_iovs[i].iov_len;
225         size_t to_copy = len < remain ? len : remain;
226 
227         memcpy(b, (char *)be->write_iovs[i].iov_base, to_copy);
228         remain -= to_copy;
229         b += to_copy;
230         if (remain == 0)
231             break;
232     }
233 
234     int n = SSL_write(be->ssl, et->tls_wbuf, b - et->tls_wbuf);
235     if (n < 0) {
236         int err = SSL_get_error(be->ssl, n);
237         if (err == SSL_ERROR_WANT_WRITE ||
238             err == SSL_ERROR_WANT_READ) {
239             ERR_clear_error();
240             return MCP_TLS_NEEDIO;
241         }
242         ERR_clear_error();
243         return MCP_TLS_ERR;
244     }
245 
246     return n;
247 }
248 
249 #else // PROXY_TLS
250 
// Stub used when TLS support is compiled out (PROXY_TLS undefined): it
// ignores its arguments and always returns 0.
int mcp_tls_writev(struct mcp_backendconn_s *be, int iovcnt) {
    (void)iovcnt; // unused without TLS
    (void)be;     // unused without TLS
    return 0;
}
256 
257 #endif // PROXY_TLS
258