improve websocket lib

This commit is contained in:
lxsang 2020-12-22 16:02:36 +01:00
parent 708e492c49
commit 4b6de31e66
3 changed files with 301 additions and 264 deletions

Binary file not shown.

142
lib/ws.c
View File

@ -37,8 +37,10 @@ ws_msg_header_t * ws_read_header(void* client)
ws_msg_header_t *header = (ws_msg_header_t *)malloc(sizeof(*header)); ws_msg_header_t *header = (ws_msg_header_t *)malloc(sizeof(*header));
// get first byte // get first byte
if(antd_recv(client, &byte, sizeof(byte)) <0) goto fail; if (antd_recv(client, &byte, sizeof(byte)) < 0)
if(BITV(byte,6) || BITV(byte,5) || BITV(byte,4)) goto fail;// all RSV bit must be 0 goto fail;
if (BITV(byte, 6) || BITV(byte, 5) || BITV(byte, 4))
goto fail; // all RSV bit must be 0
//printf("FIN: %d, RSV1: %d, RSV2: %d, RSV3:%d, opcode:%d\n", BITV(byte,7), BITV(byte,6), BITV(byte,5), BITV(byte,4),(byte & 0x0F) ); //printf("FIN: %d, RSV1: %d, RSV2: %d, RSV3:%d, opcode:%d\n", BITV(byte,7), BITV(byte,6), BITV(byte,5), BITV(byte,4),(byte & 0x0F) );
// find and opcode // find and opcode
@ -46,7 +48,8 @@ ws_msg_header_t * ws_read_header(void* client)
header->opcode = (byte & 0x0F); header->opcode = (byte & 0x0F);
// get next byte // get next byte
if(antd_recv(client, &byte, sizeof(byte)) <0) goto fail; if (antd_recv(client, &byte, sizeof(byte)) < 0)
goto fail;
//printf("MASK: %d paylen:%d\n", BITV(byte,7), (byte & 0x7F)); //printf("MASK: %d paylen:%d\n", BITV(byte,7), (byte & 0x7F));
// check mask bit, should be 1 // check mask bit, should be 1
@ -62,47 +65,52 @@ ws_msg_header_t * ws_read_header(void* client)
if (len <= 125) if (len <= 125)
{ {
header->plen = len; header->plen = len;
} else if(len == 126) }
else if (len == 126)
{ {
if(antd_recv(client,bytes, 2*sizeof(uint8_t)) <0) goto fail; if (antd_recv(client, bytes, 2 * sizeof(uint8_t)) < 0)
goto fail;
header->plen = (bytes[0] << 8) + bytes[1]; header->plen = (bytes[0] << 8) + bytes[1];
}
} else else
{ {
//read only last 4 byte //read only last 4 byte
if(antd_recv(client,bytes, 8*sizeof(uint8_t)) <0) goto fail; if (antd_recv(client, bytes, 8 * sizeof(uint8_t)) < 0)
goto fail;
header->plen = (bytes[4] << 24) + (bytes[5] << 16) + (bytes[6] << 8) + bytes[7]; header->plen = (bytes[4] << 24) + (bytes[5] << 16) + (bytes[6] << 8) + bytes[7];
} }
//printf("len: %d\n", header->plen); //printf("len: %d\n", header->plen);
// last step is to get the maskey // last step is to get the maskey
if (header->mask) if (header->mask)
if(antd_recv(client,header->mask_key, 4*sizeof(uint8_t)) <0) goto fail; if (antd_recv(client, header->mask_key, 4 * sizeof(uint8_t)) < 0)
goto fail;
//printf("key 0: %d key 1: %d key2:%d, key3: %d\n",header->mask_key[0],header->mask_key[1],header->mask_key[2], header->mask_key[3] ); //printf("key 0: %d key 1: %d key2:%d, key3: %d\n",header->mask_key[0],header->mask_key[1],header->mask_key[2], header->mask_key[3] );
// check wheather it is a ping or a close message // check wheather it is a ping or a close message
// process it and return NULL // process it and return NULL
//otherwise return the header //otherwise return the header
//return the header //return the header
switch(header->opcode){ switch (header->opcode)
{
case WS_CLOSE: // client requests to close the connection case WS_CLOSE: // client requests to close the connection
// send back a close message // send back a close message
ws_send_close(client,1000,header->mask?0:1); UNUSED(ws_send_close(client, 1000, header->mask ? 0 : 1));
//goto fail; //goto fail;
break; break;
case WS_PING: // client send a ping case WS_PING: // client send a ping
// send back a pong message // send back a pong message
ws_pong(client,header, header->mask?0:1 ); UNUSED(ws_pong(client, header, header->mask ? 0 : 1));
break; break;
default: break; default:
break;
} }
return header; return header;
fail: fail:
free(header); free(header);
return NULL; return NULL;
} }
/** /**
* Read data from client * Read data from client
@ -111,9 +119,11 @@ ws_msg_header_t * ws_read_header(void* client)
int ws_read_data(void *client, ws_msg_header_t *header, int len, uint8_t *data) int ws_read_data(void *client, ws_msg_header_t *header, int len, uint8_t *data)
{ {
// if len == -1 ==> read all remaining data to 'data'; // if len == -1 ==> read all remaining data to 'data';
if(header->plen == 0) return 0; if (header->plen == 0)
return 0;
int dlen = (len == -1 || len > (int)header->plen) ? (int)header->plen : len; int dlen = (len == -1 || len > (int)header->plen) ? (int)header->plen : len;
if((dlen = antd_recv(client,data, dlen)) <0) return -1; if ((dlen = antd_recv(client, data, dlen)) < 0)
return -1;
header->plen = header->plen - dlen; header->plen = header->plen - dlen;
// unmask received data // unmask received data
if (header->mask) if (header->mask)
@ -122,15 +132,17 @@ int ws_read_data(void* client, ws_msg_header_t* header, int len, uint8_t* data)
data[dlen] = '\0'; data[dlen] = '\0';
return dlen; return dlen;
} }
void _send_header(void* client, ws_msg_header_t header) int _send_header(void *client, ws_msg_header_t header)
{ {
uint8_t byte = 0; uint8_t byte = 0;
uint8_t bytes[8]; uint8_t bytes[8];
for(int i=0; i< 8; i++) bytes[i] = 0; for (int i = 0; i < 8; i++)
bytes[i] = 0;
//first byte |FIN|000|opcode| //first byte |FIN|000|opcode|
byte = (header.fin << 7) + header.opcode; byte = (header.fin << 7) + header.opcode;
//printf("BYTE: %d\n", byte); //printf("BYTE: %d\n", byte);
antd_send(client, &byte, 1); if (antd_send(client, &byte, 1) != 1)
return -1;
// second byte, payload length // second byte, payload length
// mask may be 0 or 1 // mask may be 0 or 1
//if(header.mask == 1) //if(header.mask == 1)
@ -138,15 +150,18 @@ void _send_header(void* client, ws_msg_header_t header)
if (header.plen <= 125) if (header.plen <= 125)
{ {
byte = (header.mask << 7) + header.plen; byte = (header.mask << 7) + header.plen;
antd_send(client, &byte, 1); if (antd_send(client, &byte, 1) != 1)
return -1;
} }
else if (header.plen < 65536) // 16 bits else if (header.plen < 65536) // 16 bits
{ {
byte = (header.mask << 7) + 126; byte = (header.mask << 7) + 126;
bytes[0] = (header.plen) >> 8; bytes[0] = (header.plen) >> 8;
bytes[1] = (header.plen) & 0x00FF; bytes[1] = (header.plen) & 0x00FF;
antd_send(client, &byte, 1); if (antd_send(client, &byte, 1) != 1)
antd_send(client, &bytes, 2); return -1;
if (antd_send(client, bytes, 2) != 2)
return -1;
} }
else // > 16 bits else // > 16 bits
{ {
@ -155,22 +170,27 @@ void _send_header(void* client, ws_msg_header_t header)
bytes[5] = ((header.plen) >> 16) & 0x00FF; bytes[5] = ((header.plen) >> 16) & 0x00FF;
bytes[6] = ((header.plen) >> 8) & 0x00FF; bytes[6] = ((header.plen) >> 8) & 0x00FF;
bytes[7] = (header.plen) & 0x00FF; bytes[7] = (header.plen) & 0x00FF;
antd_send(client, &byte, 1); if (antd_send(client, &byte, 1) != 1)
antd_send(client, &bytes, 8); return -1;
if (antd_send(client, bytes, 8) != 8)
return -1;
} }
// send mask key // send mask key
if (header.mask) if (header.mask)
{ {
antd_send(client, header.mask_key,4); if (antd_send(client, header.mask_key, 4) != 4)
return -1;
} }
return 0;
} }
/** /**
* Send a frame to client * Send a frame to client
*/ */
void ws_send_frame(void* client, uint8_t* data, ws_msg_header_t header) int ws_send_frame(void *client, uint8_t *data, ws_msg_header_t header)
{ {
uint8_t *masked; uint8_t *masked;
masked = data; masked = data;
int ret;
if (header.mask) if (header.mask)
{ {
ws_gen_mask_key(&header); ws_gen_mask_key(&header);
@ -178,18 +198,24 @@ void ws_send_frame(void* client, uint8_t* data, ws_msg_header_t header)
for (int i = 0; i < (int)header.plen; ++i) for (int i = 0; i < (int)header.plen; ++i)
masked[i] = data[i] ^ header.mask_key[i % 4]; masked[i] = data[i] ^ header.mask_key[i % 4];
} }
_send_header(client, header); if (_send_header(client, header) != 0)
return -1;
if (header.opcode == WS_TEXT) if (header.opcode == WS_TEXT)
antd_send(client,(char*)masked,header.plen); ret = antd_send(client, (char *)masked, header.plen);
else else
antd_send(client,(uint8_t*)masked,header.plen); ret = antd_send(client, (uint8_t *)masked, header.plen);
if (masked && header.mask) if (masked && header.mask)
free(masked); free(masked);
if (ret != (int)header.plen)
{
return -1;
}
return 0;
} }
/** /**
* send a text data frame to client * send a text data frame to client
*/ */
void ws_send_text(void* client, const char* data,int mask) int ws_send_text(void *client, const char *data, int mask)
{ {
ws_msg_header_t header; ws_msg_header_t header;
header.fin = 1; header.fin = 1;
@ -198,40 +224,40 @@ void ws_send_text(void* client, const char* data,int mask)
header.plen = strlen(data); header.plen = strlen(data);
//_send_header(client,header); //_send_header(client,header);
//send(client, data, header.plen,0); //send(client, data, header.plen,0);
ws_send_frame(client,(uint8_t*)data,header); return ws_send_frame(client, (uint8_t *)data, header);
} }
/** /**
* send a single binary data fram to client * send a single binary data fram to client
* not tested yet, but should work * not tested yet, but should work
*/ */
void ws_send_binary(void* client, uint8_t* data, int l, int mask) int ws_send_binary(void *client, uint8_t *data, int l, int mask)
{ {
ws_msg_header_t header; ws_msg_header_t header;
header.fin = 1; header.fin = 1;
header.opcode = WS_BIN; header.opcode = WS_BIN;
header.plen = l; header.plen = l;
header.mask = mask; header.mask = mask;
ws_send_frame(client,data, header); return ws_send_frame(client, data, header);
//_send_header(client,header); //_send_header(client,header);
//send(client, data, header.plen,0); //send(client, data, header.plen,0);
} }
/* /*
* send a file as binary data * send a file as binary data
*/ */
void ws_send_file(void* client, const char* file, int mask) int ws_send_file(void *client, const char *file, int mask)
{ {
uint8_t buff[1024]; uint8_t buff[1024];
FILE *ptr; FILE *ptr;
ptr = fopen(file, "rb"); ptr = fopen(file, "rb");
if (!ptr) if (!ptr)
{ {
ws_send_close(client,1011,mask); return ws_send_close(client, 1011, mask);
return;
} }
ws_msg_header_t header; ws_msg_header_t header;
size_t size; size_t size;
int first_frame = 1; int first_frame = 1;
int ret = 0;
//ws_send_frame(client,buff,header); //ws_send_frame(client,buff,header);
header.mask = mask; header.mask = mask;
while (!feof(ptr)) while (!feof(ptr))
@ -251,48 +277,56 @@ void ws_send_file(void* client, const char* file, int mask)
header.opcode = 0; header.opcode = 0;
header.plen = size; header.plen = size;
//printf("FIN: %d OC:%d\n", header.fin, header.opcode); //printf("FIN: %d OC:%d\n", header.fin, header.opcode);
ws_send_frame(client,buff,header); ret += ws_send_frame(client, buff, header);
} }
fclose(ptr); fclose(ptr);
if (ret != 0)
{
return -1;
}
return 0;
} }
/** /**
* Not tested yet * Not tested yet
* but should work * but should work
*/ */
void ws_pong(void* client, ws_msg_header_t* oheader, int mask) int ws_pong(void *client, ws_msg_header_t *oheader, int mask)
{ {
ws_msg_header_t pheader; ws_msg_header_t pheader;
int ret;
pheader.fin = 1; pheader.fin = 1;
pheader.opcode = WS_PONG; pheader.opcode = WS_PONG;
pheader.plen = oheader->plen; pheader.plen = oheader->plen;
pheader.mask = mask; pheader.mask = mask;
uint8_t *data = (uint8_t *)malloc(oheader->plen); uint8_t *data = (uint8_t *)malloc(oheader->plen);
if(!data) return; if (!data)
return -1;
if (ws_read_data(client, oheader, pheader.plen, data) == -1) if (ws_read_data(client, oheader, pheader.plen, data) == -1)
{ {
ERROR("Cannot read ping data %d", pheader.plen); ERROR("Cannot read ping data %d", pheader.plen);
free(data); free(data);
return; return -1;
} }
ws_send_frame(client,data,pheader); ret = ws_send_frame(client, data, pheader);
free(data); free(data);
//_send_header(client, pheader); //_send_header(client, pheader);
//send(client, data, len, 0); //send(client, data, len, 0);
return ret;
} }
void ws_ping(void* client, const char* echo, int mask) int ws_ping(void *client, const char *echo, int mask)
{ {
ws_msg_header_t pheader; ws_msg_header_t pheader;
pheader.fin = 1; pheader.fin = 1;
pheader.opcode = WS_PING; pheader.opcode = WS_PING;
pheader.plen = strlen(echo); pheader.plen = strlen(echo);
pheader.mask = mask; pheader.mask = mask;
ws_send_frame(client,(uint8_t*)echo,pheader); return ws_send_frame(client, (uint8_t *)echo, pheader);
} }
/* /*
* Not tested yet, but should work * Not tested yet, but should work
*/ */
void ws_send_close(void* client, unsigned int status, int mask) int ws_send_close(void *client, unsigned int status, int mask)
{ {
//printf("CLOSED\n"); //printf("CLOSED\n");
ws_msg_header_t header; ws_msg_header_t header;
@ -310,7 +344,7 @@ void ws_send_close(void* client, unsigned int status, int mask)
header.mask_key[1] = bytes[1]; header.mask_key[1] = bytes[1];
bytes[0] = bytes[1] ^ bytes[1]; bytes[0] = bytes[1] ^ bytes[1];
}*/ }*/
ws_send_frame(client,bytes,header); return ws_send_frame(client, bytes, header);
//_send_header(client, header); //_send_header(client, header);
//send(client,bytes,2,0); //send(client,bytes,2,0);
} }
@ -436,7 +470,8 @@ int ws_client_connect(ws_client_t* wsclient, port_config_t pcnf)
} }
wsclient->ssl_ctx = SSL_CTX_new(method); wsclient->ssl_ctx = SSL_CTX_new(method);
ssl_err = ERR_get_error(); ssl_err = ERR_get_error();
if (!wsclient->ssl_ctx) { if (!wsclient->ssl_ctx)
{
ERROR("SSL_CTX_new: %s", ERR_error_string(ssl_err, NULL)); ERROR("SSL_CTX_new: %s", ERR_error_string(ssl_err, NULL));
return -1; return -1;
} }
@ -460,19 +495,22 @@ int ws_client_connect(ws_client_t* wsclient, port_config_t pcnf)
if (wsclient->sslcert && wsclient->sslkey) if (wsclient->sslcert && wsclient->sslkey)
{ {
if (SSL_CTX_use_certificate_file(wsclient->ssl_ctx,wsclient->sslcert, SSL_FILETYPE_PEM) <= 0) { if (SSL_CTX_use_certificate_file(wsclient->ssl_ctx, wsclient->sslcert, SSL_FILETYPE_PEM) <= 0)
{
ssl_err = ERR_get_error(); ssl_err = ERR_get_error();
ERROR("SSL_CTX_use_certificate_file: %s", ERR_error_string(ssl_err, NULL)); ERROR("SSL_CTX_use_certificate_file: %s", ERR_error_string(ssl_err, NULL));
return -1; return -1;
} }
if (wsclient->sslpasswd) if (wsclient->sslpasswd)
SSL_CTX_set_default_passwd_cb_userdata(wsclient->ssl_ctx, (void *)wsclient->sslpasswd); SSL_CTX_set_default_passwd_cb_userdata(wsclient->ssl_ctx, (void *)wsclient->sslpasswd);
if (SSL_CTX_use_PrivateKey_file(wsclient->ssl_ctx,wsclient->sslkey, SSL_FILETYPE_PEM) <= 0) { if (SSL_CTX_use_PrivateKey_file(wsclient->ssl_ctx, wsclient->sslkey, SSL_FILETYPE_PEM) <= 0)
{
ssl_err = ERR_get_error(); ssl_err = ERR_get_error();
ERROR("SSL_CTX_use_PrivateKey_file: %s", ERR_error_string(ssl_err, NULL)); ERROR("SSL_CTX_use_PrivateKey_file: %s", ERR_error_string(ssl_err, NULL));
return -1; return -1;
} }
if (SSL_CTX_check_private_key(wsclient->ssl_ctx) == 0) { if (SSL_CTX_check_private_key(wsclient->ssl_ctx) == 0)
{
ssl_err = ERR_get_error(); ssl_err = ERR_get_error();
ERROR("SSL_CTX_check_private_key: %s", ERR_error_string(ssl_err, NULL)); ERROR("SSL_CTX_check_private_key: %s", ERR_error_string(ssl_err, NULL));
return -1; return -1;
@ -528,9 +566,6 @@ int ws_client_connect(ws_client_t* wsclient, port_config_t pcnf)
return 0; return 0;
} }
int ws_open_handshake(ws_client_t *client) int ws_open_handshake(ws_client_t *client)
{ {
char buf[MAX_BUFF]; char buf[MAX_BUFF];
@ -562,7 +597,8 @@ int ws_open_handshake(ws_client_t* client)
{ {
//LOG("Handshake sucessfull\n"); //LOG("Handshake sucessfull\n");
done = 1; done = 1;
} else }
else
{ {
ERROR("WS handshake, Wrong key %s vs %s", token, SERVER_WS_KEY); ERROR("WS handshake, Wrong key %s vs %s", token, SERVER_WS_KEY);
return -1; return -1;

View File

@ -22,7 +22,8 @@
#define CLIENT_RQ "GET /%s HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n" #define CLIENT_RQ "GET /%s HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n"
#define SERVER_WS_KEY "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" #define SERVER_WS_KEY "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
typedef struct{ typedef struct
{
uint8_t fin; uint8_t fin;
uint8_t opcode; uint8_t opcode;
unsigned int plen; unsigned int plen;
@ -30,7 +31,8 @@ typedef struct{
uint8_t mask_key[4]; uint8_t mask_key[4];
} ws_msg_header_t; } ws_msg_header_t;
typedef struct{ typedef struct
{
const char *host; const char *host;
const char *resource; const char *resource;
antd_client_t *antdsock; antd_client_t *antdsock;
@ -43,17 +45,16 @@ typedef struct{
void *ssl_ctx; void *ssl_ctx;
} ws_client_t; } ws_client_t;
ws_msg_header_t *ws_read_header(void *); ws_msg_header_t *ws_read_header(void *);
void ws_send_frame(void* , uint8_t* , ws_msg_header_t ); int ws_send_frame(void *, uint8_t *, ws_msg_header_t);
void ws_pong(void* client, ws_msg_header_t*, int mask); int ws_pong(void *client, ws_msg_header_t *, int mask);
void ws_ping(void* client, const char* echo, int mask); int ws_ping(void *client, const char *echo, int mask);
void ws_send_text(void* client, const char* data,int mask); int ws_send_text(void *client, const char *data, int mask);
void ws_send_close(void* client, unsigned int status, int mask); int ws_send_close(void *client, unsigned int status, int mask);
void ws_send_file(void* client, const char* file, int mask); int ws_send_file(void *client, const char *file, int mask);
void ws_send_binary(void* client, uint8_t* data, int l, int mask); int ws_send_binary(void *client, uint8_t *data, int l, int mask);
int ws_read_data(void *, ws_msg_header_t *, int, uint8_t *); int ws_read_data(void *, ws_msg_header_t *, int, uint8_t *);
int request_socket(const char *ip, int port); int request_socket(const char *ip, int port);