diff --git a/libratbox/src/openssl.c b/libratbox/src/openssl.c index 3ae97108..e3f199d8 100644 --- a/libratbox/src/openssl.c +++ b/libratbox/src/openssl.c @@ -33,6 +33,16 @@ #include "openssl_ratbox.h" +typedef enum +{ + RB_FD_TLS_DIRECTION_IN = 0, + RB_FD_TLS_DIRECTION_OUT = 1 +} rb_fd_tls_direction; + +#define SSL_P(x) ((SSL *)((x)->ssl)) + + + static SSL_CTX *ssl_ctx = NULL; struct ssl_connect @@ -46,6 +56,7 @@ static void rb_ssl_connect_realcb(rb_fde_t *, int, struct ssl_connect *); static void rb_ssl_tryconn_timeout_cb(rb_fde_t *, void *); static void rb_ssl_timeout(rb_fde_t *, void *); static void rb_ssl_tryaccept(rb_fde_t *, void *); +static const char *get_ssl_error(unsigned long); @@ -64,14 +75,41 @@ get_last_err(void) return err; } +static void +rb_ssl_init_fd(rb_fde_t *const F, const rb_fd_tls_direction dir) +{ + (void) get_last_err(); + + F->ssl = SSL_new(ssl_ctx); + + if(F->ssl == NULL) + { + rb_lib_log("%s: SSL_new: %s", __func__, get_ssl_error(get_last_err())); + rb_close(F); + return; + } + + switch(dir) + { + case RB_FD_TLS_DIRECTION_IN: + SSL_set_accept_state(SSL_P(F)); + break; + case RB_FD_TLS_DIRECTION_OUT: + SSL_set_connect_state(SSL_P(F)); + break; + } + + SSL_set_fd(SSL_P(F), rb_get_fd(F)); +} + static void rb_ssl_accept_common(rb_fde_t *const new_F) { - int ssl_err = SSL_accept((SSL *) new_F->ssl); + int ssl_err = SSL_accept(SSL_P(new_F)); if(ssl_err <= 0) { - switch((ssl_err = SSL_get_error((SSL *) new_F->ssl, ssl_err))) + switch((ssl_err = SSL_get_error(SSL_P(new_F), ssl_err))) { case SSL_ERROR_SYSCALL: if(rb_ignore_errno(errno)) @@ -101,14 +139,14 @@ rb_ssl_tryaccept(rb_fde_t *const F, void *const data) { lrb_assert(F->accept != NULL); - if(! SSL_is_init_finished((SSL *) F->ssl)) + if(! SSL_is_init_finished(SSL_P(F))) { int flags; - int ssl_err = SSL_accept((SSL *) F->ssl); + int ssl_err = SSL_accept(SSL_P(F)); if(ssl_err <= 0) { - switch(ssl_err = SSL_get_error((SSL *) F->ssl, ssl_err)) + switch(ssl_err = SSL_get_error(SSL_P(F), ssl_err)) { case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_WRITE: @@ -145,14 +183,14 @@ rb_ssl_tryaccept(rb_fde_t *const F, void *const data) static void rb_ssl_tryconn_cb(rb_fde_t *const F, void *const data) { - if(! SSL_is_init_finished((SSL *) F->ssl)) + if(! SSL_is_init_finished(SSL_P(F))) { struct ssl_connect *const sconn = data; - int ssl_err = SSL_connect((SSL *) F->ssl); + int ssl_err = SSL_connect(SSL_P(F)); if(ssl_err <= 0) { - switch(ssl_err = SSL_get_error((SSL *) F->ssl, ssl_err)) + switch(ssl_err = SSL_get_error(SSL_P(F), ssl_err)) { case SSL_ERROR_SYSCALL: if(rb_ignore_errno(errno)) @@ -181,6 +219,8 @@ rb_ssl_tryconn_cb(rb_fde_t *const F, void *const data) static void rb_ssl_tryconn(rb_fde_t *const F, const int status, void *const data) { + lrb_assert(F != NULL); + struct ssl_connect *const sconn = data; if(status != RB_OK) @@ -191,17 +231,14 @@ rb_ssl_tryconn(rb_fde_t *const F, const int status, void *const data) F->type |= RB_FD_SSL; - F->ssl = SSL_new(ssl_ctx); - SSL_set_fd((SSL *) F->ssl, F->fd); - SSL_set_connect_state((SSL *) F->ssl); - rb_settimeout(F, sconn->timeout, rb_ssl_tryconn_timeout_cb, sconn); + rb_ssl_init_fd(F, RB_FD_TLS_DIRECTION_OUT); - int ssl_err = SSL_connect((SSL *) F->ssl); + int ssl_err = SSL_connect(SSL_P(F)); if(ssl_err <= 0) { - switch(ssl_err = SSL_get_error((SSL *) F->ssl, ssl_err)) + switch(ssl_err = SSL_get_error(SSL_P(F), ssl_err)) { case SSL_ERROR_SYSCALL: if(rb_ignore_errno(errno)) @@ -247,16 +284,15 @@ rb_ssl_read_or_write(const int r_or_w, rb_fde_t *const F, void *const rbuf, cons { ssize_t ret; unsigned long err; - SSL *const ssl = F->ssl; if(r_or_w == 0) - ret = (ssize_t) SSL_read(ssl, rbuf, (int)count); + ret = (ssize_t) SSL_read(SSL_P(F), rbuf, (int)count); else - ret = (ssize_t) SSL_write(ssl, wbuf, (int)count); + ret = (ssize_t) SSL_write(SSL_P(F), wbuf, (int)count); if(ret < 0) { - switch(SSL_get_error(ssl, ret)) + switch(SSL_get_error(SSL_P(F), ret)) { case SSL_ERROR_WANT_READ: errno = EAGAIN; @@ -302,16 +338,17 @@ rb_ssl_shutdown(rb_fde_t *const F) if(F == NULL || F->ssl == NULL) return; - SSL_set_shutdown((SSL *) F->ssl, SSL_RECEIVED_SHUTDOWN); + SSL_set_shutdown(SSL_P(F), SSL_RECEIVED_SHUTDOWN); for(int i = 0; i < 4; i++) { - if(SSL_shutdown((SSL *) F->ssl)) + if(SSL_shutdown(SSL_P(F))) break; } get_last_err(); - SSL_free((SSL *) F->ssl); + SSL_free(SSL_P(F)); + F->ssl = NULL; } int @@ -531,11 +568,11 @@ rb_get_ssl_certfp(rb_fde_t *const F, uint8_t certfp[const RB_SSL_CERTFP_LEN], co return 0; } - X509 *const cert = SSL_get_peer_certificate((SSL *) F->ssl); + X509 *const cert = SSL_get_peer_certificate(SSL_P(F)); if(cert == NULL) return 0; - int res = SSL_get_verify_result((SSL *) F->ssl); + int res = SSL_get_verify_result(SSL_P(F)); switch(res) { case X509_V_OK: @@ -581,8 +618,8 @@ rb_ssl_get_cipher(rb_fde_t *const F) static char buf[512]; - const char *const version = SSL_get_version(F->ssl); - const char *const cipher = SSL_get_cipher_name(F->ssl); + const char *const version = SSL_get_version(SSL_P(F)); + const char *const cipher = SSL_get_cipher_name(SSL_P(F)); (void) rb_snprintf(buf, sizeof buf, "%s, %s", version, cipher); @@ -613,9 +650,7 @@ rb_ssl_start_accepted(rb_fde_t *const F, ACCB *const cb, void *const data, const (void) memset(&F->accept->S, 0x00, sizeof F->accept->S); rb_settimeout(F, timeout, rb_ssl_timeout, NULL); - F->ssl = SSL_new(ssl_ctx); - SSL_set_fd((SSL *) F->ssl, rb_get_fd(F)); - SSL_set_accept_state((SSL *) F->ssl); + rb_ssl_init_fd(F, RB_FD_TLS_DIRECTION_IN); rb_ssl_accept_common(F); } @@ -632,9 +667,7 @@ rb_ssl_accept_setup(rb_fde_t *const srv_F, rb_fde_t *const cli_F, struct sockadd (void) memcpy(&cli_F->accept->S, st, addrlen); rb_settimeout(cli_F, 10, rb_ssl_timeout, NULL); - cli_F->ssl = SSL_new(ssl_ctx); - SSL_set_fd((SSL *) cli_F->ssl, rb_get_fd(cli_F)); - SSL_set_accept_state((SSL *) cli_F->ssl); + rb_ssl_init_fd(cli_F, RB_FD_TLS_DIRECTION_IN); rb_ssl_accept_common(cli_F); } @@ -655,15 +688,13 @@ rb_ssl_start_connected(rb_fde_t *const F, CNCB *const callback, void *const data F->type |= RB_FD_SSL; rb_settimeout(F, sconn->timeout, rb_ssl_tryconn_timeout_cb, sconn); - F->ssl = SSL_new(ssl_ctx); - SSL_set_fd((SSL *) F->ssl, F->fd); - SSL_set_connect_state((SSL *) F->ssl); + rb_ssl_init_fd(F, RB_FD_TLS_DIRECTION_OUT); - int ssl_err = SSL_connect((SSL *) F->ssl); + int ssl_err = SSL_connect(SSL_P(F)); if(ssl_err <= 0) { - switch(ssl_err = SSL_get_error((SSL *) F->ssl, ssl_err)) + switch((ssl_err = SSL_get_error(SSL_P(F), ssl_err))) { case SSL_ERROR_SYSCALL: if(rb_ignore_errno(errno))