1
0
mirror of https://github.com/lxsang/antd-tunnel-plugin synced 2024-11-16 09:48:21 +01:00

improve tunnel API

This commit is contained in:
lxsang 2020-12-28 21:36:38 +01:00
parent 061a3c418e
commit 17c4396474

View File

@ -77,6 +77,54 @@ typedef struct
static antd_tunnel_t g_tunnel; static antd_tunnel_t g_tunnel;
static int guard_read(int fd, void* buffer, size_t size)
{
int n = 0;
int read_len;
int st;
while(n != (int)size)
{
read_len = (int)size - n;
st = read(fd,buffer + n,read_len);
if(st == -1)
{
ERROR( "Unable to read from #%d: %s", fd, strerror(errno));
return -1;
}
if(st == 0)
{
ERROR("Endpoint %d is closed", fd);
return -1;
}
n += st;
}
return n;
}
static int guard_write(int fd, void* buffer, size_t size)
{
int n = 0;
int write_len;
int st;
while(n != (int)size)
{
write_len = (int)size - n;
st = write(fd,buffer + n,write_len);
if(st == -1)
{
ERROR("Unable to write to #%d: %s", fd, strerror(errno));
return -1;
}
if(st == 0)
{
ERROR("Endpoint %d is closed", fd);
return -1;
}
n += st;
}
return n;
}
static int mk_socket(const char *name, char *path) static int mk_socket(const char *name, char *path)
{ {
struct sockaddr_un address; struct sockaddr_un address;
@ -127,14 +175,14 @@ static int mk_socket(const char *name, char *path)
static int msg_check_number(int fd, uint16_t number) static int msg_check_number(int fd, uint16_t number)
{ {
uint16_t value; uint16_t value;
if (read(fd, &value, sizeof(value)) == -1) if (guard_read(fd, &value, sizeof(value)) == -1)
{ {
ERROR("Unable to read integer value on socket %d: %s", fd, strerror(errno)); ERROR("Unable to read integer value on socket %d: %s", fd, strerror(errno));
return -1; return -1;
} }
if (number != value) if (number != value)
{ {
ERROR("Value mismatches: %0x%02X, expected %0x%02X", value, number); ERROR("Value mismatches: 0x%02X, expected 0x%02X", value, number);
return -1; return -1;
} }
return 0; return 0;
@ -143,7 +191,7 @@ static int msg_check_number(int fd, uint16_t number)
static int msg_read_string(int fd, char* buffer, uint8_t max_length) static int msg_read_string(int fd, char* buffer, uint8_t max_length)
{ {
uint8_t size; uint8_t size;
if(read(fd,&size,sizeof(size)) == -1) if(guard_read(fd,&size,sizeof(size)) == -1)
{ {
ERROR("Unable to read string size: %s", strerror(errno)); ERROR("Unable to read string size: %s", strerror(errno));
return -1; return -1;
@ -153,7 +201,7 @@ static int msg_read_string(int fd, char* buffer, uint8_t max_length)
ERROR("String length exceeds the maximal value of ", max_length); ERROR("String length exceeds the maximal value of ", max_length);
return -1; return -1;
} }
if(read(fd,buffer,size) == -1) if(guard_read(fd,buffer,size) == -1)
{ {
ERROR("Unable to read string to buffer: %s", strerror(errno)); ERROR("Unable to read string to buffer: %s", strerror(errno));
return -1; return -1;
@ -165,7 +213,7 @@ static int msg_read_string(int fd, char* buffer, uint8_t max_length)
static uint8_t *msg_read_payload(int fd, uint32_t *size) static uint8_t *msg_read_payload(int fd, uint32_t *size)
{ {
uint8_t *data; uint8_t *data;
if (read(fd, size, sizeof(*size)) == -1) if (guard_read(fd, size, sizeof(*size)) == -1)
{ {
ERROR("Unable to read payload data size: %s", strerror(errno)); ERROR("Unable to read payload data size: %s", strerror(errno));
return NULL; return NULL;
@ -181,7 +229,7 @@ static uint8_t *msg_read_payload(int fd, uint32_t *size)
ERROR("Unable to allocate memory for payload data: %s", strerror(errno)); ERROR("Unable to allocate memory for payload data: %s", strerror(errno));
return NULL; return NULL;
} }
if (read(fd, data, *size) == -1) if (guard_read(fd, data, *size) == -1)
{ {
ERROR("Unable to read payload data to buffer: %s", strerror(errno)); ERROR("Unable to read payload data to buffer: %s", strerror(errno));
free(data); free(data);
@ -198,7 +246,7 @@ static int msg_read(int fd, antd_tunnel_msg_t *msg)
ERROR("Unable to check begin magic number on socket: %d", fd); ERROR("Unable to check begin magic number on socket: %d", fd);
return -1; return -1;
} }
if (read(fd, &msg->header.type, sizeof(msg->header.type)) == -1) if (guard_read(fd, &msg->header.type, sizeof(msg->header.type)) == -1)
{ {
ERROR("Unable to read msg type: %s", strerror(errno)); ERROR("Unable to read msg type: %s", strerror(errno));
return -1; return -1;
@ -208,12 +256,12 @@ static int msg_read(int fd, antd_tunnel_msg_t *msg)
ERROR("Unknown msg type: %d", msg->header.type); ERROR("Unknown msg type: %d", msg->header.type);
return -1; return -1;
} }
if (read(fd, &msg->header.channel_id, sizeof(msg->header.channel_id)) == -1) if (guard_read(fd, &msg->header.channel_id, sizeof(msg->header.channel_id)) == -1)
{ {
ERROR("Unable to read msg channel id"); ERROR("Unable to read msg channel id");
return -1; return -1;
} }
if (read(fd, &msg->header.client_id, sizeof(msg->header.client_id)) == -1) if (guard_read(fd, &msg->header.client_id, sizeof(msg->header.client_id)) == -1)
{ {
ERROR("Unable to read msg client id"); ERROR("Unable to read msg client id");
return -1; return -1;
@ -239,31 +287,31 @@ static int msg_write(int fd, antd_tunnel_msg_t *msg)
{ {
// write begin magic number // write begin magic number
uint16_t number = MSG_MAGIC_BEGIN; uint16_t number = MSG_MAGIC_BEGIN;
if (write(fd, &number, sizeof(number)) == -1) if (guard_write(fd, &number, sizeof(number)) == -1)
{ {
ERROR("Unable to write begin magic number: %s", strerror(errno)); ERROR("Unable to write begin magic number: %s", strerror(errno));
return -1; return -1;
} }
// write type // write type
if (write(fd, &msg->header.type, sizeof(msg->header.type)) == -1) if (guard_write(fd, &msg->header.type, sizeof(msg->header.type)) == -1)
{ {
ERROR("Unable to write msg type: %s", strerror(errno)); ERROR("Unable to write msg type: %s", strerror(errno));
return -1; return -1;
} }
// write channel id // write channel id
if (write(fd, &msg->header.channel_id, sizeof(msg->header.channel_id)) == -1) if (guard_write(fd, &msg->header.channel_id, sizeof(msg->header.channel_id)) == -1)
{ {
ERROR("Unable to write msg channel id: %s", strerror(errno)); ERROR("Unable to write msg channel id: %s", strerror(errno));
return -1; return -1;
} }
//write client id //write client id
if (write(fd, &msg->header.client_id, sizeof(msg->header.client_id)) == -1) if (guard_write(fd, &msg->header.client_id, sizeof(msg->header.client_id)) == -1)
{ {
ERROR("Unable to write msg client id: %s", strerror(errno)); ERROR("Unable to write msg client id: %s", strerror(errno));
return -1; return -1;
} }
// write payload len // write payload len
if (write(fd, &msg->header.size, sizeof(msg->header.size)) == -1) if (guard_write(fd, &msg->header.size, sizeof(msg->header.size)) == -1)
{ {
ERROR("Unable to write msg payload length: %s", strerror(errno)); ERROR("Unable to write msg payload length: %s", strerror(errno));
return -1; return -1;
@ -271,14 +319,14 @@ static int msg_write(int fd, antd_tunnel_msg_t *msg)
// write payload data // write payload data
if (msg->header.size > 0) if (msg->header.size > 0)
{ {
if (write(fd, msg->data, msg->header.size) == -1) if (guard_write(fd, msg->data, msg->header.size) == -1)
{ {
ERROR("Unable to write msg payload: %s", strerror(errno)); ERROR("Unable to write msg payload: %s", strerror(errno));
return -1; return -1;
} }
} }
number = MSG_MAGIC_END; number = MSG_MAGIC_END;
if (write(fd, &number, sizeof(number)) == -1) if (guard_write(fd, &number, sizeof(number)) == -1)
{ {
ERROR("Unable to write end magic number: %s", strerror(errno)); ERROR("Unable to write end magic number: %s", strerror(errno));
return -1; return -1;