OpenSSL: Add generic direction enum for negotiation setup.

Also define an SSL_P(x) macro as in other backends and use that to refer
to the client session context.
This commit is contained in:
Aaron Jones 2016-09-15 17:06:05 +00:00
parent e569720fe1
commit 4e9441a1cb
No known key found for this signature in database
GPG key ID: EC6F86EE9CD840B5

View file

@ -33,6 +33,16 @@
#include "openssl_ratbox.h" #include "openssl_ratbox.h"
typedef enum
{
RB_FD_TLS_DIRECTION_IN = 0,
RB_FD_TLS_DIRECTION_OUT = 1
} rb_fd_tls_direction;
#define SSL_P(x) ((SSL *)((x)->ssl))
static SSL_CTX *ssl_ctx = NULL; static SSL_CTX *ssl_ctx = NULL;
struct ssl_connect struct ssl_connect
@ -46,6 +56,7 @@ static void rb_ssl_connect_realcb(rb_fde_t *, int, struct ssl_connect *);
static void rb_ssl_tryconn_timeout_cb(rb_fde_t *, void *); static void rb_ssl_tryconn_timeout_cb(rb_fde_t *, void *);
static void rb_ssl_timeout(rb_fde_t *, void *); static void rb_ssl_timeout(rb_fde_t *, void *);
static void rb_ssl_tryaccept(rb_fde_t *, void *); static void rb_ssl_tryaccept(rb_fde_t *, void *);
static const char *get_ssl_error(unsigned long);
@ -64,14 +75,41 @@ get_last_err(void)
return err; return err;
} }
static void
rb_ssl_init_fd(rb_fde_t *const F, const rb_fd_tls_direction dir)
{
(void) get_last_err();
F->ssl = SSL_new(ssl_ctx);
if(F->ssl == NULL)
{
rb_lib_log("%s: SSL_new: %s", __func__, get_ssl_error(get_last_err()));
rb_close(F);
return;
}
switch(dir)
{
case RB_FD_TLS_DIRECTION_IN:
SSL_set_accept_state(SSL_P(F));
break;
case RB_FD_TLS_DIRECTION_OUT:
SSL_set_connect_state(SSL_P(F));
break;
}
SSL_set_fd(SSL_P(F), rb_get_fd(F));
}
static void static void
rb_ssl_accept_common(rb_fde_t *const new_F) rb_ssl_accept_common(rb_fde_t *const new_F)
{ {
int ssl_err = SSL_accept((SSL *) new_F->ssl); int ssl_err = SSL_accept(SSL_P(new_F));
if(ssl_err <= 0) if(ssl_err <= 0)
{ {
switch((ssl_err = SSL_get_error((SSL *) new_F->ssl, ssl_err))) switch((ssl_err = SSL_get_error(SSL_P(new_F), ssl_err)))
{ {
case SSL_ERROR_SYSCALL: case SSL_ERROR_SYSCALL:
if(rb_ignore_errno(errno)) if(rb_ignore_errno(errno))
@ -101,14 +139,14 @@ rb_ssl_tryaccept(rb_fde_t *const F, void *const data)
{ {
lrb_assert(F->accept != NULL); lrb_assert(F->accept != NULL);
if(! SSL_is_init_finished((SSL *) F->ssl)) if(! SSL_is_init_finished(SSL_P(F)))
{ {
int flags; int flags;
int ssl_err = SSL_accept((SSL *) F->ssl); int ssl_err = SSL_accept(SSL_P(F));
if(ssl_err <= 0) if(ssl_err <= 0)
{ {
switch(ssl_err = SSL_get_error((SSL *) F->ssl, ssl_err)) switch(ssl_err = SSL_get_error(SSL_P(F), ssl_err))
{ {
case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_WRITE:
@ -145,14 +183,14 @@ rb_ssl_tryaccept(rb_fde_t *const F, void *const data)
static void static void
rb_ssl_tryconn_cb(rb_fde_t *const F, void *const data) rb_ssl_tryconn_cb(rb_fde_t *const F, void *const data)
{ {
if(! SSL_is_init_finished((SSL *) F->ssl)) if(! SSL_is_init_finished(SSL_P(F)))
{ {
struct ssl_connect *const sconn = data; struct ssl_connect *const sconn = data;
int ssl_err = SSL_connect((SSL *) F->ssl); int ssl_err = SSL_connect(SSL_P(F));
if(ssl_err <= 0) if(ssl_err <= 0)
{ {
switch(ssl_err = SSL_get_error((SSL *) F->ssl, ssl_err)) switch(ssl_err = SSL_get_error(SSL_P(F), ssl_err))
{ {
case SSL_ERROR_SYSCALL: case SSL_ERROR_SYSCALL:
if(rb_ignore_errno(errno)) if(rb_ignore_errno(errno))
@ -181,6 +219,8 @@ rb_ssl_tryconn_cb(rb_fde_t *const F, void *const data)
static void static void
rb_ssl_tryconn(rb_fde_t *const F, const int status, void *const data) rb_ssl_tryconn(rb_fde_t *const F, const int status, void *const data)
{ {
lrb_assert(F != NULL);
struct ssl_connect *const sconn = data; struct ssl_connect *const sconn = data;
if(status != RB_OK) if(status != RB_OK)
@ -191,17 +231,14 @@ rb_ssl_tryconn(rb_fde_t *const F, const int status, void *const data)
F->type |= RB_FD_SSL; F->type |= RB_FD_SSL;
F->ssl = SSL_new(ssl_ctx);
SSL_set_fd((SSL *) F->ssl, F->fd);
SSL_set_connect_state((SSL *) F->ssl);
rb_settimeout(F, sconn->timeout, rb_ssl_tryconn_timeout_cb, sconn); rb_settimeout(F, sconn->timeout, rb_ssl_tryconn_timeout_cb, sconn);
rb_ssl_init_fd(F, RB_FD_TLS_DIRECTION_OUT);
int ssl_err = SSL_connect((SSL *) F->ssl); int ssl_err = SSL_connect(SSL_P(F));
if(ssl_err <= 0) if(ssl_err <= 0)
{ {
switch(ssl_err = SSL_get_error((SSL *) F->ssl, ssl_err)) switch(ssl_err = SSL_get_error(SSL_P(F), ssl_err))
{ {
case SSL_ERROR_SYSCALL: case SSL_ERROR_SYSCALL:
if(rb_ignore_errno(errno)) if(rb_ignore_errno(errno))
@ -247,16 +284,15 @@ rb_ssl_read_or_write(const int r_or_w, rb_fde_t *const F, void *const rbuf, cons
{ {
ssize_t ret; ssize_t ret;
unsigned long err; unsigned long err;
SSL *const ssl = F->ssl;
if(r_or_w == 0) if(r_or_w == 0)
ret = (ssize_t) SSL_read(ssl, rbuf, (int)count); ret = (ssize_t) SSL_read(SSL_P(F), rbuf, (int)count);
else else
ret = (ssize_t) SSL_write(ssl, wbuf, (int)count); ret = (ssize_t) SSL_write(SSL_P(F), wbuf, (int)count);
if(ret < 0) if(ret < 0)
{ {
switch(SSL_get_error(ssl, ret)) switch(SSL_get_error(SSL_P(F), ret))
{ {
case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_READ:
errno = EAGAIN; errno = EAGAIN;
@ -302,16 +338,17 @@ rb_ssl_shutdown(rb_fde_t *const F)
if(F == NULL || F->ssl == NULL) if(F == NULL || F->ssl == NULL)
return; return;
SSL_set_shutdown((SSL *) F->ssl, SSL_RECEIVED_SHUTDOWN); SSL_set_shutdown(SSL_P(F), SSL_RECEIVED_SHUTDOWN);
for(int i = 0; i < 4; i++) for(int i = 0; i < 4; i++)
{ {
if(SSL_shutdown((SSL *) F->ssl)) if(SSL_shutdown(SSL_P(F)))
break; break;
} }
get_last_err(); get_last_err();
SSL_free((SSL *) F->ssl); SSL_free(SSL_P(F));
F->ssl = NULL;
} }
int int
@ -531,11 +568,11 @@ rb_get_ssl_certfp(rb_fde_t *const F, uint8_t certfp[const RB_SSL_CERTFP_LEN], co
return 0; return 0;
} }
X509 *const cert = SSL_get_peer_certificate((SSL *) F->ssl); X509 *const cert = SSL_get_peer_certificate(SSL_P(F));
if(cert == NULL) if(cert == NULL)
return 0; return 0;
int res = SSL_get_verify_result((SSL *) F->ssl); int res = SSL_get_verify_result(SSL_P(F));
switch(res) switch(res)
{ {
case X509_V_OK: case X509_V_OK:
@ -581,8 +618,8 @@ rb_ssl_get_cipher(rb_fde_t *const F)
static char buf[512]; static char buf[512];
const char *const version = SSL_get_version(F->ssl); const char *const version = SSL_get_version(SSL_P(F));
const char *const cipher = SSL_get_cipher_name(F->ssl); const char *const cipher = SSL_get_cipher_name(SSL_P(F));
(void) rb_snprintf(buf, sizeof buf, "%s, %s", version, cipher); (void) rb_snprintf(buf, sizeof buf, "%s, %s", version, cipher);
@ -613,9 +650,7 @@ rb_ssl_start_accepted(rb_fde_t *const F, ACCB *const cb, void *const data, const
(void) memset(&F->accept->S, 0x00, sizeof F->accept->S); (void) memset(&F->accept->S, 0x00, sizeof F->accept->S);
rb_settimeout(F, timeout, rb_ssl_timeout, NULL); rb_settimeout(F, timeout, rb_ssl_timeout, NULL);
F->ssl = SSL_new(ssl_ctx); rb_ssl_init_fd(F, RB_FD_TLS_DIRECTION_IN);
SSL_set_fd((SSL *) F->ssl, rb_get_fd(F));
SSL_set_accept_state((SSL *) F->ssl);
rb_ssl_accept_common(F); rb_ssl_accept_common(F);
} }
@ -632,9 +667,7 @@ rb_ssl_accept_setup(rb_fde_t *const srv_F, rb_fde_t *const cli_F, struct sockadd
(void) memcpy(&cli_F->accept->S, st, addrlen); (void) memcpy(&cli_F->accept->S, st, addrlen);
rb_settimeout(cli_F, 10, rb_ssl_timeout, NULL); rb_settimeout(cli_F, 10, rb_ssl_timeout, NULL);
cli_F->ssl = SSL_new(ssl_ctx); rb_ssl_init_fd(cli_F, RB_FD_TLS_DIRECTION_IN);
SSL_set_fd((SSL *) cli_F->ssl, rb_get_fd(cli_F));
SSL_set_accept_state((SSL *) cli_F->ssl);
rb_ssl_accept_common(cli_F); rb_ssl_accept_common(cli_F);
} }
@ -655,15 +688,13 @@ rb_ssl_start_connected(rb_fde_t *const F, CNCB *const callback, void *const data
F->type |= RB_FD_SSL; F->type |= RB_FD_SSL;
rb_settimeout(F, sconn->timeout, rb_ssl_tryconn_timeout_cb, sconn); rb_settimeout(F, sconn->timeout, rb_ssl_tryconn_timeout_cb, sconn);
F->ssl = SSL_new(ssl_ctx); rb_ssl_init_fd(F, RB_FD_TLS_DIRECTION_OUT);
SSL_set_fd((SSL *) F->ssl, F->fd);
SSL_set_connect_state((SSL *) F->ssl);
int ssl_err = SSL_connect((SSL *) F->ssl); int ssl_err = SSL_connect(SSL_P(F));
if(ssl_err <= 0) if(ssl_err <= 0)
{ {
switch(ssl_err = SSL_get_error((SSL *) F->ssl, ssl_err)) switch((ssl_err = SSL_get_error(SSL_P(F), ssl_err)))
{ {
case SSL_ERROR_SYSCALL: case SSL_ERROR_SYSCALL:
if(rb_ignore_errno(errno)) if(rb_ignore_errno(errno))