diff --git a/tunnel.c b/tunnel.c index 6a9ddb8..46f0db0 100644 --- a/tunnel.c +++ b/tunnel.c @@ -77,6 +77,54 @@ typedef struct static antd_tunnel_t g_tunnel; +static int guard_read(int fd, void* buffer, size_t size) +{ + int n = 0; + int read_len; + int st; + while(n != (int)size) + { + read_len = (int)size - n; + st = read(fd,buffer + n,read_len); + if(st == -1) + { + ERROR( "Unable to read from #%d: %s", fd, strerror(errno)); + return -1; + } + if(st == 0) + { + ERROR("Endpoint %d is closed", fd); + return -1; + } + n += st; + } + return n; +} + +static int guard_write(int fd, void* buffer, size_t size) +{ + int n = 0; + int write_len; + int st; + while(n != (int)size) + { + write_len = (int)size - n; + st = write(fd,buffer + n,write_len); + if(st == -1) + { + ERROR("Unable to write to #%d: %s", fd, strerror(errno)); + return -1; + } + if(st == 0) + { + ERROR("Endpoint %d is closed", fd); + return -1; + } + n += st; + } + return n; +} + static int mk_socket(const char *name, char *path) { struct sockaddr_un address; @@ -127,14 +175,14 @@ static int mk_socket(const char *name, char *path) static int msg_check_number(int fd, uint16_t number) { uint16_t value; - if (read(fd, &value, sizeof(value)) == -1) + if (guard_read(fd, &value, sizeof(value)) == -1) { ERROR("Unable to read integer value on socket %d: %s", fd, strerror(errno)); return -1; } if (number != value) { - ERROR("Value mismatches: %0x%02X, expected %0x%02X", value, number); + ERROR("Value mismatches: 0x%02X, expected 0x%02X", value, number); return -1; } return 0; @@ -143,7 +191,7 @@ static int msg_check_number(int fd, uint16_t number) static int msg_read_string(int fd, char* buffer, uint8_t max_length) { uint8_t size; - if(read(fd,&size,sizeof(size)) == -1) + if(guard_read(fd,&size,sizeof(size)) == -1) { ERROR("Unable to read string size: %s", strerror(errno)); return -1; @@ -153,7 +201,7 @@ static int msg_read_string(int fd, char* buffer, uint8_t max_length) ERROR("String length exceeds the maximal value of ", max_length); return -1; } - if(read(fd,buffer,size) == -1) + if(guard_read(fd,buffer,size) == -1) { ERROR("Unable to read string to buffer: %s", strerror(errno)); return -1; @@ -165,7 +213,7 @@ static int msg_read_string(int fd, char* buffer, uint8_t max_length) static uint8_t *msg_read_payload(int fd, uint32_t *size) { uint8_t *data; - if (read(fd, size, sizeof(*size)) == -1) + if (guard_read(fd, size, sizeof(*size)) == -1) { ERROR("Unable to read payload data size: %s", strerror(errno)); return NULL; @@ -181,7 +229,7 @@ static uint8_t *msg_read_payload(int fd, uint32_t *size) ERROR("Unable to allocate memory for payload data: %s", strerror(errno)); return NULL; } - if (read(fd, data, *size) == -1) + if (guard_read(fd, data, *size) == -1) { ERROR("Unable to read payload data to buffer: %s", strerror(errno)); free(data); @@ -198,7 +246,7 @@ static int msg_read(int fd, antd_tunnel_msg_t *msg) ERROR("Unable to check begin magic number on socket: %d", fd); return -1; } - if (read(fd, &msg->header.type, sizeof(msg->header.type)) == -1) + if (guard_read(fd, &msg->header.type, sizeof(msg->header.type)) == -1) { ERROR("Unable to read msg type: %s", strerror(errno)); return -1; @@ -208,12 +256,12 @@ static int msg_read(int fd, antd_tunnel_msg_t *msg) ERROR("Unknown msg type: %d", msg->header.type); return -1; } - if (read(fd, &msg->header.channel_id, sizeof(msg->header.channel_id)) == -1) + if (guard_read(fd, &msg->header.channel_id, sizeof(msg->header.channel_id)) == -1) { ERROR("Unable to read msg channel id"); return -1; } - if (read(fd, &msg->header.client_id, sizeof(msg->header.client_id)) == -1) + if (guard_read(fd, &msg->header.client_id, sizeof(msg->header.client_id)) == -1) { ERROR("Unable to read msg client id"); return -1; @@ -239,31 +287,31 @@ static int msg_write(int fd, antd_tunnel_msg_t *msg) { // write begin magic number uint16_t number = MSG_MAGIC_BEGIN; - if (write(fd, &number, sizeof(number)) == -1) + if (guard_write(fd, &number, sizeof(number)) == -1) { ERROR("Unable to write begin magic number: %s", strerror(errno)); return -1; } // write type - if (write(fd, &msg->header.type, sizeof(msg->header.type)) == -1) + if (guard_write(fd, &msg->header.type, sizeof(msg->header.type)) == -1) { ERROR("Unable to write msg type: %s", strerror(errno)); return -1; } // write channel id - if (write(fd, &msg->header.channel_id, sizeof(msg->header.channel_id)) == -1) + if (guard_write(fd, &msg->header.channel_id, sizeof(msg->header.channel_id)) == -1) { ERROR("Unable to write msg channel id: %s", strerror(errno)); return -1; } //write client id - if (write(fd, &msg->header.client_id, sizeof(msg->header.client_id)) == -1) + if (guard_write(fd, &msg->header.client_id, sizeof(msg->header.client_id)) == -1) { ERROR("Unable to write msg client id: %s", strerror(errno)); return -1; } // write payload len - if (write(fd, &msg->header.size, sizeof(msg->header.size)) == -1) + if (guard_write(fd, &msg->header.size, sizeof(msg->header.size)) == -1) { ERROR("Unable to write msg payload length: %s", strerror(errno)); return -1; @@ -271,14 +319,14 @@ static int msg_write(int fd, antd_tunnel_msg_t *msg) // write payload data if (msg->header.size > 0) { - if (write(fd, msg->data, msg->header.size) == -1) + if (guard_write(fd, msg->data, msg->header.size) == -1) { ERROR("Unable to write msg payload: %s", strerror(errno)); return -1; } } number = MSG_MAGIC_END; - if (write(fd, &number, sizeof(number)) == -1) + if (guard_write(fd, &number, sizeof(number)) == -1) { ERROR("Unable to write end magic number: %s", strerror(errno)); return -1;