wsockd: implement websocket handshake part

This commit is contained in:
William Pitcock 2016-04-02 17:05:40 -05:00
parent 64fae2607a
commit 1160f6c9bf
4 changed files with 295 additions and 10 deletions

View file

@ -3,5 +3,5 @@ AM_CFLAGS=$(WARNFLAGS)
AM_CPPFLAGS = -I../include -I../librb/include
wsockd_SOURCES = wsockd.c
wsockd_SOURCES = wsockd.c sha1.c
wsockd_LDADD = ../librb/src/librb.la

143
wsockd/sha1.c Normal file
View file

@ -0,0 +1,143 @@
/*
* Based on the SHA-1 C implementation by Steve Reid <steve@edmweb.com>
* 100% Public Domain
*
* Test Vectors (from FIPS PUB 180-1)
* "abc"
* A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D
* "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"
* 84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1
* A million repetitions of "a"
* 34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F
*/
#include <string.h>
#ifdef _WIN32
#include <winsock2.h> // for htonl()
#else
#include <netinet/in.h> // for htonl()
#endif
#include "sha1.h"
#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits))))
// blk0() and blk() perform the initial expand. blk0() deals with host endianess
#define blk0(i) (block[i] = htonl(block[i]))
#define blk(i) (block[i&15] = rol(block[(i+13)&15]^block[(i+8)&15]^block[(i+2)&15]^block[i&15],1))
// (R0+R1), R2, R3, R4 are the different operations (rounds) used in SHA1
#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30);
#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30);
#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30);
#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30);
#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30);
// hash a single 512-bit block. this is the core of the algorithm
static uint32_t sha1_transform(SHA1 *sha1, const uint8_t buffer[SHA1_BLOCK_LENGTH]) {
uint32_t a, b, c, d, e;
uint32_t block[SHA1_BLOCK_LENGTH / 4];
memcpy(&block, buffer, SHA1_BLOCK_LENGTH);
// copy sha1->state[] to working variables
a = sha1->state[0];
b = sha1->state[1];
c = sha1->state[2];
d = sha1->state[3];
e = sha1->state[4];
// 4 rounds of 20 operations each (loop unrolled)
R0(a,b,c,d,e, 0); R0(e,a,b,c,d, 1); R0(d,e,a,b,c, 2); R0(c,d,e,a,b, 3);
R0(b,c,d,e,a, 4); R0(a,b,c,d,e, 5); R0(e,a,b,c,d, 6); R0(d,e,a,b,c, 7);
R0(c,d,e,a,b, 8); R0(b,c,d,e,a, 9); R0(a,b,c,d,e,10); R0(e,a,b,c,d,11);
R0(d,e,a,b,c,12); R0(c,d,e,a,b,13); R0(b,c,d,e,a,14); R0(a,b,c,d,e,15);
R1(e,a,b,c,d,16); R1(d,e,a,b,c,17); R1(c,d,e,a,b,18); R1(b,c,d,e,a,19);
R2(a,b,c,d,e,20); R2(e,a,b,c,d,21); R2(d,e,a,b,c,22); R2(c,d,e,a,b,23);
R2(b,c,d,e,a,24); R2(a,b,c,d,e,25); R2(e,a,b,c,d,26); R2(d,e,a,b,c,27);
R2(c,d,e,a,b,28); R2(b,c,d,e,a,29); R2(a,b,c,d,e,30); R2(e,a,b,c,d,31);
R2(d,e,a,b,c,32); R2(c,d,e,a,b,33); R2(b,c,d,e,a,34); R2(a,b,c,d,e,35);
R2(e,a,b,c,d,36); R2(d,e,a,b,c,37); R2(c,d,e,a,b,38); R2(b,c,d,e,a,39);
R3(a,b,c,d,e,40); R3(e,a,b,c,d,41); R3(d,e,a,b,c,42); R3(c,d,e,a,b,43);
R3(b,c,d,e,a,44); R3(a,b,c,d,e,45); R3(e,a,b,c,d,46); R3(d,e,a,b,c,47);
R3(c,d,e,a,b,48); R3(b,c,d,e,a,49); R3(a,b,c,d,e,50); R3(e,a,b,c,d,51);
R3(d,e,a,b,c,52); R3(c,d,e,a,b,53); R3(b,c,d,e,a,54); R3(a,b,c,d,e,55);
R3(e,a,b,c,d,56); R3(d,e,a,b,c,57); R3(c,d,e,a,b,58); R3(b,c,d,e,a,59);
R4(a,b,c,d,e,60); R4(e,a,b,c,d,61); R4(d,e,a,b,c,62); R4(c,d,e,a,b,63);
R4(b,c,d,e,a,64); R4(a,b,c,d,e,65); R4(e,a,b,c,d,66); R4(d,e,a,b,c,67);
R4(c,d,e,a,b,68); R4(b,c,d,e,a,69); R4(a,b,c,d,e,70); R4(e,a,b,c,d,71);
R4(d,e,a,b,c,72); R4(c,d,e,a,b,73); R4(b,c,d,e,a,74); R4(a,b,c,d,e,75);
R4(e,a,b,c,d,76); R4(d,e,a,b,c,77); R4(c,d,e,a,b,78); R4(b,c,d,e,a,79);
// add the working variables back into sha1->state[]
sha1->state[0] += a;
sha1->state[1] += b;
sha1->state[2] += c;
sha1->state[3] += d;
sha1->state[4] += e;
// wipe variables
a = b = c = d = e = 0;
return a; // return a to avoid dead-store warning from clang static analyzer
}
void sha1_init(SHA1 *sha1) {
sha1->state[0] = 0x67452301;
sha1->state[1] = 0xEFCDAB89;
sha1->state[2] = 0x98BADCFE;
sha1->state[3] = 0x10325476;
sha1->state[4] = 0xC3D2E1F0;
sha1->count = 0;
}
void sha1_update(SHA1 *sha1, const uint8_t *data, size_t length) {
size_t i, j;
j = (size_t)((sha1->count >> 3) & 63);
sha1->count += (length << 3);
if ((j + length) > 63) {
i = 64 - j;
memcpy(&sha1->buffer[j], data, i);
sha1_transform(sha1, sha1->buffer);
for (; i + 63 < length; i += 64) {
sha1_transform(sha1, &data[i]);
}
j = 0;
} else {
i = 0;
}
memcpy(&sha1->buffer[j], &data[i], length - i);
}
void sha1_final(SHA1 *sha1, uint8_t digest[SHA1_DIGEST_LENGTH]) {
uint32_t i;
uint8_t count[8];
for (i = 0; i < 8; i++) {
// this is endian independent
count[i] = (uint8_t)((sha1->count >> ((7 - (i & 7)) * 8)) & 255);
}
sha1_update(sha1, (uint8_t *)"\200", 1);
while ((sha1->count & 504) != 448) {
sha1_update(sha1, (uint8_t *)"\0", 1);
}
sha1_update(sha1, count, 8);
for (i = 0; i < SHA1_DIGEST_LENGTH; i++) {
digest[i] = (uint8_t)((sha1->state[i >> 2] >> ((3 - (i & 3)) * 8)) & 255);
}
memset(sha1, 0, sizeof(*sha1));
}

25
wsockd/sha1.h Normal file
View file

@ -0,0 +1,25 @@
/*
* Based on the SHA-1 C implementation by Steve Reid <steve@edmweb.com>
* 100% Public Domain
*/
#ifndef SHA1_H
#define SHA1_H
#include <stddef.h>
#include <stdint.h>
#define SHA1_BLOCK_LENGTH 64
#define SHA1_DIGEST_LENGTH 20
typedef struct {
uint32_t state[5];
uint64_t count;
uint8_t buffer[SHA1_BLOCK_LENGTH];
} SHA1;
void sha1_init(SHA1 *sha1);
void sha1_update(SHA1 *sha1, const uint8_t *data, size_t length);
void sha1_final(SHA1 *sha1, uint8_t digest[SHA1_DIGEST_LENGTH]);
#endif // SHA1_H

View file

@ -21,12 +21,17 @@
*/
#include "stdinc.h"
#include "sha1.h"
#define MAXPASSFD 4
#ifndef READBUF_SIZE
#define READBUF_SIZE 16384
#endif
#define WEBSOCKET_SERVER_KEY "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
#define WEBSOCKET_ANSWER_STRING_1 "HTTP/1.1 101 Switching Protocols\r\nAccess-Control-Allow-Origin: *\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "
#define WEBSOCKET_ANSWER_STRING_2 "\r\n\r\n"
static void setup_signals(void);
static pid_t ppid;
@ -71,8 +76,8 @@ typedef struct _conn
rb_dlink_node node;
mod_ctl_t *ctl;
buf_head_t modbuf_out;
buf_head_t modbuf_in;
rawbuf_head_t *modbuf_out;
rawbuf_head_t *modbuf_in;
buf_head_t plainbuf_out;
buf_head_t plainbuf_in;
@ -86,23 +91,33 @@ typedef struct _conn
uint64_t plain_in;
uint64_t plain_out;
uint8_t flags;
char client_key[37]; /* maximum 36 bytes + nul */
} conn_t;
static void close_conn(conn_t * conn, int wait_plain, const char *fmt, ...);
static void conn_mod_read_cb(rb_fde_t *fd, void *data);
static void conn_plain_read_cb(rb_fde_t *fd, void *data);
#define FLAG_CORK 0x01
#define FLAG_DEAD 0x02
#define FLAG_WSOCK 0x04
#define FLAG_KEYED 0x08
#define IsCork(x) ((x)->flags & FLAG_CORK)
#define IsDead(x) ((x)->flags & FLAG_DEAD)
#define IsWS(x) ((x)->flags & FLAG_WSOCK)
#define IsKeyed(x) ((x)->flags & FLAG_KEYED)
#define SetCork(x) ((x)->flags |= FLAG_CORK)
#define SetDead(x) ((x)->flags |= FLAG_DEAD)
#define SetWS(x) ((x)->flags |= FLAG_WSOCK)
#define SetKeyed(x) ((x)->flags |= FLAG_KEYED)
#define ClearCork(x) ((x)->flags &= ~FLAG_CORK)
#define ClearDead(x) ((x)->flags &= ~FLAG_DEAD)
#define ClearWS(x) ((x)->flags &= ~FLAG_WSOCK)
#define ClearKeyed(x) ((x)->flags &= ~FLAG_KEYED)
#define NO_WAIT 0x0
#define WAIT_PLAIN 0x1
@ -112,6 +127,8 @@ typedef struct _conn
#define CONN_HASH_SIZE 2000
#define connid_hash(x) (&connid_hash_table[(x % CONN_HASH_SIZE)])
static const char *remote_closed = "Remote host closed the connection";
static rb_dlink_list connid_hash_table[CONN_HASH_SIZE];
static rb_dlink_list dead_list;
@ -196,8 +213,8 @@ free_conn(conn_t * conn)
rb_linebuf_donebuf(&conn->plainbuf_in);
rb_linebuf_donebuf(&conn->plainbuf_out);
rb_linebuf_donebuf(&conn->modbuf_in);
rb_linebuf_donebuf(&conn->modbuf_out);
rb_free_rawbuffer(conn->modbuf_in);
rb_free_rawbuffer(conn->modbuf_out);
rb_free(conn);
}
@ -217,6 +234,56 @@ clean_dead_conns(void *unused)
dead_list.tail = dead_list.head = NULL;
}
static void
conn_mod_write_sendq(rb_fde_t *fd, void *data)
{
conn_t *conn = data;
const char *err;
int retlen;
if(IsDead(conn))
return;
while((retlen = rb_rawbuf_flush(conn->modbuf_out, fd)) > 0)
conn->mod_out += retlen;
if(retlen == 0 || (retlen < 0 && !rb_ignore_errno(errno)))
{
if(retlen == 0)
close_conn(conn, WAIT_PLAIN, "%s", remote_closed);
err = strerror(errno);
close_conn(conn, WAIT_PLAIN, "Write error: %s", err);
return;
}
if(rb_rawbuf_length(conn->modbuf_out) > 0)
rb_setselect(conn->mod_fd, RB_SELECT_WRITE, conn_mod_write_sendq, conn);
else
rb_setselect(conn->mod_fd, RB_SELECT_WRITE, NULL, NULL);
if(IsCork(conn) && rb_rawbuf_length(conn->modbuf_out) == 0)
{
ClearCork(conn);
conn_plain_read_cb(conn->plain_fd, conn);
}
}
static void
conn_mod_write(conn_t * conn, void *data, size_t len)
{
if(IsDead(conn)) /* no point in queueing to a dead man */
return;
rb_rawbuf_append(conn->modbuf_out, data, len);
}
static void
conn_plain_write(conn_t * conn, void *data, size_t len)
{
if(IsDead(conn)) /* again no point in queueing to dead men */
return;
rb_linebuf_put(&conn->plainbuf_out, data, len);
}
static void
mod_write_ctl(rb_fde_t *F, void *data)
{
@ -270,7 +337,7 @@ close_conn(conn_t * conn, int wait_plain, const char *fmt, ...)
if(IsDead(conn))
return;
rb_linebuf_flush(conn->mod_fd, &conn->modbuf_out);
rb_rawbuf_flush(conn->modbuf_out, conn->mod_fd);
rb_linebuf_flush(conn->plain_fd, &conn->plainbuf_out);
rb_close(conn->mod_fd);
SetDead(conn);
@ -312,8 +379,8 @@ make_conn(mod_ctl_t * ctl, rb_fde_t *mod_fd, rb_fde_t *plain_fd)
rb_linebuf_newbuf(&conn->plainbuf_in);
rb_linebuf_newbuf(&conn->plainbuf_out);
rb_linebuf_newbuf(&conn->modbuf_in);
rb_linebuf_newbuf(&conn->modbuf_out);
conn->modbuf_in = rb_new_rawbuffer();
conn->modbuf_out = rb_new_rawbuffer();
return conn;
}
@ -335,10 +402,59 @@ conn_mod_handshake_process(conn_t *conn)
while (1)
{
size_t dolen = rb_linebuf_get(&conn->modbuf_in, inbuf, READBUF_SIZE, LINEBUF_COMPLETE, LINEBUF_PARSED);
char *p = NULL;
size_t dolen = rb_rawbuf_get(conn->modbuf_in, inbuf, sizeof inbuf);
if (!dolen)
break;
if ((p = strcasestr(inbuf, "Sec-WebSocket-Key:")) != NULL)
{
char *start, *end;
start = p + strlen("Sec-WebSocket-Key:");
for (; start < (inbuf + READBUF_SIZE) && *start; start++)
{
if (*start != ' ' && *start != '\t')
break;
}
for (end = start; end < (inbuf + READBUF_SIZE) && *end; end++)
{
if (*end == '\r' || *end == '\n')
{
*end = '\0';
break;
}
}
rb_strlcpy(conn->client_key, start, sizeof(conn->client_key));
SetKeyed(conn);
}
}
if (IsKeyed(conn))
{
SHA1 sha1;
uint8_t digest[SHA1_DIGEST_LENGTH];
char *resp;
sha1_init(&sha1);
sha1_update(&sha1, (uint8_t *) conn->client_key, strlen(conn->client_key));
sha1_update(&sha1, (uint8_t *) WEBSOCKET_SERVER_KEY, strlen(WEBSOCKET_SERVER_KEY));
sha1_final(&sha1, digest);
resp = (char *) rb_base64_encode(digest, SHA1_DIGEST_LENGTH);
conn_mod_write(conn, WEBSOCKET_ANSWER_STRING_1, strlen(WEBSOCKET_ANSWER_STRING_1));
conn_mod_write(conn, resp, strlen(resp));
conn_mod_write(conn, WEBSOCKET_ANSWER_STRING_2, strlen(WEBSOCKET_ANSWER_STRING_2));
rb_free(resp);
}
conn_mod_write_sendq(conn->mod_fd, conn);
}
static void
@ -375,7 +491,7 @@ conn_mod_handshake_cb(rb_fde_t *fd, void *data)
return;
}
int res = rb_linebuf_parse(&conn->modbuf_in, inbuf, length, 0);
rb_rawbuf_append(conn->modbuf_in, inbuf, length);
conn_mod_handshake_process(conn);
if (length < sizeof(inbuf))
@ -578,6 +694,7 @@ main(int argc, char **argv)
setup_signals();
rb_lib_init(NULL, NULL, NULL, 0, maxfd, 1024, 4096);
rb_linebuf_init(4096);
rb_init_rawbuffers(4096);
mod_ctl = rb_malloc(sizeof(mod_ctl_t));
mod_ctl->F = rb_open(ctlfd, RB_FD_SOCKET, "ircd control socket");