diff --git a/tunnel.c b/tunnel.c index 0eacc9a..7a1019b 100644 --- a/tunnel.c +++ b/tunnel.c @@ -269,8 +269,73 @@ static int msg_write(int fd, antd_tunnel_msg_t* msg) } return 0; } +static void write_msg_to_client(antd_tunnel_msg_t* msg, antd_client_t* client) +{ + uint8_t* buffer; + int long_value = 0; + int offset = 0; + long_value = msg->header.size + + sizeof((int)MSG_MAGIC_BEGIN) + + sizeof(msg->header.type) + + sizeof(msg->header.channel_id) + + sizeof(msg->header.client_id) + + sizeof(msg->header.size) + + sizeof((int)MSG_MAGIC_END); + buffer = (uint8_t*) malloc(long_value); + if(buffer == NULL) + { + ERROR("unable to allocate memory for write"); + return; + } + // magic + long_value = (int) MSG_MAGIC_BEGIN; + (void)memcpy(buffer,&long_value,sizeof(long_value)); + offset += sizeof(long_value); + // type + (void)memcpy(buffer+offset,&msg->header.type,sizeof(msg->header.type)); + offset += sizeof(msg->header.type); + // channel id + (void)memcpy(buffer+offset,&msg->header.channel_id,sizeof(msg->header.channel_id)); + offset += sizeof(msg->header.channel_id); + // client id + (void)memcpy(buffer+offset,&msg->header.client_id,sizeof(msg->header.client_id)); + offset += sizeof(msg->header.client_id); + // payload length + (void)memcpy(buffer+offset,&msg->header.size,sizeof(msg->header.size)); + offset += sizeof(msg->header.size); + // payload + (void)memcpy(buffer+offset,msg->data,msg->header.size); + offset += msg->header.size; + // magic end + long_value = (int) MSG_MAGIC_END; + (void)memcpy(buffer+offset,&long_value,sizeof(long_value)); + offset += sizeof(long_value); + + // write it to the websocket + ws_b(client,buffer, offset); + + free(buffer); + +} +static void unsubscribe(bst_node_t* node, void** argv, int argc) +{ + // request client to unsubscribe + (void) argc; + antd_tunnel_channel_t* channel = (antd_tunnel_channel_t*) argv[0]; + antd_tunnel_msg_t msg; + if(node->data != NULL) + { + msg.header.channel_id = simple_hash(channel->name); + msg.header.client_id = node->key; + msg.header.type = CHANNEL_UNSUBSCRIBE; + msg.header.size = 0; + msg.data = NULL; + write_msg_to_client(&msg,(antd_client_t*)node->data); + } +} static void destroy_channel(antd_tunnel_channel_t* channel) { + void* argc[1]; if(channel == NULL) return; if(channel->sock != -1) @@ -278,7 +343,8 @@ static void destroy_channel(antd_tunnel_channel_t* channel) (void) close(channel->sock); channel->sock = -1; } - /** TODO: send message to all subcribers before close*/ + argc[0] = (void*)channel; + bst_for_each(channel->subscribers,unsubscribe,argc, 1); bst_free(channel->subscribers); free(channel); } @@ -351,7 +417,6 @@ static void channel_open(int fd, const char* name) static void channel_close(antd_tunnel_channel_t* channel) { antd_tunnel_msg_t msg; - msg.data = NULL; msg.header.channel_id = 0; msg.header.client_id = 0; @@ -370,7 +435,6 @@ static void channel_close(antd_tunnel_channel_t* channel) LOG("Close channel: %s (%d)", channel->name, channel->sock); destroy_channel(channel); } - //g_tunnel.channels = bst_delete(g_tunnel.channels, node->key); } } static void monitor_hotline(int listen_fd) @@ -430,66 +494,12 @@ static void monitor_hotline(int listen_fd) break; } } -static void write_msg_to_client(antd_tunnel_msg_t* msg, antd_client_t* client) -{ - uint8_t* buffer; - int long_value = 0; - int offset = 0; - long_value = msg->header.size + - sizeof((int)MSG_MAGIC_BEGIN) + - sizeof(msg->header.type) + - sizeof(msg->header.channel_id) + - sizeof(msg->header.client_id) + - sizeof(msg->header.size) + - msg->header.size + - sizeof((int)MSG_MAGIC_END); - buffer = (uint8_t*) malloc(long_value); - if(buffer == NULL) - { - ERROR("unable to allocate memory for write"); - return; - } - // magic - long_value = (int) MSG_MAGIC_BEGIN; - (void)memcpy(buffer,&long_value,sizeof(long_value)); - offset += sizeof(long_value); - - // type - (void)memcpy(buffer+offset,&msg->header.type,sizeof(msg->header.type)); - offset += sizeof(msg->header.type); - - // channel id - (void)memcpy(buffer+offset,&msg->header.channel_id,sizeof(msg->header.channel_id)); - offset += sizeof(msg->header.channel_id); - - // client id - (void)memcpy(buffer+offset,&msg->header.client_id,sizeof(msg->header.client_id)); - offset += sizeof(msg->header.client_id); - - // payload length - (void)memcpy(buffer+offset,&msg->header.size,sizeof(msg->header.size)); - offset += sizeof(msg->header.size); - - // payload - (void)memcpy(buffer+offset,&msg->data,sizeof(msg->header.size)); - offset += msg->header.size; - - // magic end - long_value = (int) MSG_MAGIC_END; - (void)memcpy(buffer,&long_value,sizeof(long_value)); - offset += sizeof(long_value); - - // write it to the websocket - ws_b(client,buffer, offset); - - free(buffer); - -} static void handle_channel(bst_node_t* node, void** args, int argc) { antd_tunnel_msg_t msg; (void) argc; fd_set* fd_in = (fd_set*) args[0]; + list_t* channel_list = (list_t*) args[1]; antd_tunnel_channel_t* channel = (antd_tunnel_channel_t*) node->data; bst_node_t * client; antd_client_t* rq; @@ -517,9 +527,14 @@ static void handle_channel(bst_node_t* node, void** args, int argc) case CHANNEL_OK: case CHANNEL_ERROR: case CHANNEL_DATA: + case CHANNEL_UNSUBSCRIBE: // forward message to the correct client in the channel msg.header.channel_id = node->key; client = bst_find(channel->subscribers, msg.header.client_id); + if(msg.header.type == CHANNEL_UNSUBSCRIBE) + { + channel->subscribers = bst_delete(channel->subscribers, msg.header.client_id); + } if(client != NULL) { rq = (antd_client_t*) client->data; @@ -533,10 +548,12 @@ static void handle_channel(bst_node_t* node, void** args, int argc) ERROR("Unable to find client %d to write on channel %s", msg.header.client_id, channel->name); } break; + case CHANNEL_CLOSE: // close the current channel channel_close(channel); node->data = NULL; + list_put_ptr(channel_list, node); break; default: LOG("Message type %d is not supported in client-application communication", msg.header.type); @@ -567,6 +584,8 @@ static void* multiplex(void* data_p) struct timeval timeout; int rc; void *args[2]; + list_t closed_channels; + item_t item; antd_tunnel_t* tunnel_p = (antd_tunnel_t*) data_p; while(status == 0) { @@ -588,8 +607,9 @@ static void* multiplex(void* data_p) status = 1; break; case 0: - // time out - // sleep here + timeout.tv_sec = 0; + timeout.tv_usec = 500; // 5 ms + select(0, NULL, NULL, NULL, &timeout); break; // we have data default: @@ -599,8 +619,16 @@ static void* multiplex(void* data_p) monitor_hotline(tunnel_p->hotline); } pthread_mutex_lock(&tunnel_p->lock); + closed_channels = list_init(); args[0] = (void*) &fd_in; - bst_for_each(tunnel_p->channels, handle_channel,args, 1); + args[1] = (void*) &closed_channels; + bst_for_each(tunnel_p->channels, handle_channel,args, 2); + list_for_each(item, closed_channels) + { + tunnel_p->channels = bst_delete(tunnel_p->channels, ((bst_node_t*)item->value.ptr)->key); + item->value.ptr = NULL; + } + list_free(&closed_channels); pthread_mutex_unlock(&tunnel_p->lock); } } @@ -699,9 +727,16 @@ static void process_client_message(antd_tunnel_msg_t* msg, antd_client_t* client write_msg_to_client(msg, client); return; } - (void)memcpy(buff, msg->data, msg->header.size); - buff[msg->header.size] = '\0'; - hash_val = simple_hash(buff); + if(msg->header.size > 0) + { + (void)memcpy(buff, msg->data, msg->header.size); + buff[msg->header.size] = '\0'; + hash_val = simple_hash(buff); + } + else + { + hash_val = msg->header.channel_id; + } node = bst_find(g_tunnel.channels, hash_val); if(node) { @@ -717,7 +752,10 @@ static void process_client_message(antd_tunnel_msg_t* msg, antd_client_t* client msg->header.channel_id = hash_val; msg->header.size = sizeof(g_tunnel.id_allocator); (void)memcpy(buff, &g_tunnel.id_allocator, sizeof(g_tunnel.id_allocator)); + msg->data = (uint8_t*)buff; write_msg_to_client(msg, client); + msg->header.client_id = g_tunnel.id_allocator; + msg->header.type = CHANNEL_SUBSCRIBE; } else { @@ -726,17 +764,27 @@ static void process_client_message(antd_tunnel_msg_t* msg, antd_client_t* client msg->header.channel_id = hash_val; msg->header.size = 0; write_msg_to_client(msg, client); + msg->header.type = CHANNEL_UNSUBSCRIBE; + } + // forward to publisher + + if(msg_write(channel->sock, msg) == -1) + { + ERROR("Unable to forward subscribe/unsubscribe message to %s", channel->name); } } } else { - msg->header.type = CHANNEL_ERROR; (void) snprintf(buff, BUFFLEN, "Channel not found"); msg->header.size = strlen(buff); msg->data = (uint8_t*)buff; ERROR("%s", buff); - write_msg_to_client(msg, client); + if(msg->header.type == CHANNEL_SUBSCRIBE) + { + msg->header.type = CHANNEL_ERROR; + write_msg_to_client(msg, client); + } } break; @@ -745,18 +793,73 @@ static void process_client_message(antd_tunnel_msg_t* msg, antd_client_t* client break; } } + +static void unsubscribe_notify_handle(bst_node_t* node, void** argv, int argc) +{ + (void) argc; + antd_client_t* client = (antd_client_t*)argv[0]; + antd_tunnel_channel_t* channel = (antd_tunnel_channel_t*)argv[1]; + list_t* list = (list_t*)argv[2]; + antd_tunnel_msg_t msg; + if((antd_client_t*)node->data == client) + { + if(channel != NULL) + { + msg.header.type = CHANNEL_UNSUBSCRIBE; + msg.header.channel_id = simple_hash(channel->name); + msg.header.client_id = node->key; + msg.header.size = 0; + msg.data = NULL; + if(msg_write(channel->sock, &msg) == -1) + { + ERROR("Unable to send unsubscribe notification of client %d to channel %s (%d)", node->key, channel->name, channel->sock); + } + } + if(list != NULL) + { + list_put_ptr(list, node); + } + } +} + +static void unsubscribe_notify(bst_node_t* node, void** argv, int argc) +{ + (void)argc; + void * pargv[3]; + antd_client_t* client = (antd_client_t*) argv[0]; + antd_tunnel_channel_t* channel = (antd_tunnel_channel_t*) node->data; + list_t list = list_init(); + item_t item; + if(channel != NULL) + { + pargv[0] = (void*) client; + pargv[1] = (void*) channel; + pargv[2] = (void*) &list; + bst_for_each(channel->subscribers,unsubscribe_notify_handle,pargv, 3); + list_for_each(item, list) + { + channel->subscribers = bst_delete(channel->subscribers, ((bst_node_t*)item->value.ptr)->key); + item->value.ptr = NULL; + } + } + list_free(&list); + +} + void *handle(void *rq_data) { antd_request_t *rq = (antd_request_t *)rq_data; antd_task_t *task = antd_create_task(NULL, (void *)rq, NULL, time(NULL)); ws_msg_header_t *h = NULL; antd_tunnel_msg_t msg; - struct timeval timeout; uint8_t* buffer; + struct timeval timeout; + int status; fd_set fd_in; - int rc, long_value, offset; + int long_value, offset; task->priority++; - int cl_fd = ((antd_client_t *)rq->client)->sock; + + void * argv[1]; if(g_tunnel.initialized == 0) { ERROR("The tunnel plugin is not initialised correctly"); @@ -765,115 +868,120 @@ void *handle(void *rq_data) if (ws_enable(rq->request)) { timeout.tv_sec = 0; - timeout.tv_usec = 500; + timeout.tv_usec = 500; // 5 ms FD_ZERO(&fd_in); - FD_SET(cl_fd, &fd_in); - rc = select(cl_fd + 1, &fd_in, NULL, NULL, &timeout); - switch (rc) - { - case -1: - LOG("Error on select(): %s\n", strerror(errno)); - ws_close(rq->client, 1011); - /** TODO: remove all subscriber of this ws connection */ - return task; - case 0: - // time out - break; - // we have data - default: - pthread_mutex_lock(&g_tunnel.lock); - h = ws_read_header(rq->client); - pthread_mutex_unlock(&g_tunnel.lock); - if (h) - { - if (h->mask == 0) + FD_SET(((antd_client_t*)(rq->client))->sock, &fd_in); + status = select(((antd_client_t*)(rq->client))->sock + 1, &fd_in, NULL, NULL, &timeout); + switch (status) + { + case -1: + LOG("Error %d on select()\n", errno); + break; + case 0: + timeout.tv_sec = 0; + timeout.tv_usec = 500; // 5 ms + select(0, NULL, NULL, NULL, &timeout); + break; + default: + argv[0] = (void*) rq->client; + pthread_mutex_lock(&g_tunnel.lock); + h = ws_read_header(rq->client); + pthread_mutex_unlock(&g_tunnel.lock); + if (h) { - LOG("Data is not mask"); - // kill the child process - free(h); - pthread_mutex_lock(&g_tunnel.lock); - ws_close(rq->client, 1011); - pthread_mutex_unlock(&g_tunnel.lock); - /** TODO: remove all subscriber of this ws connection */ - return task; - } - if (h->opcode == WS_CLOSE) - { - LOG("Websocket: connection closed"); - pthread_mutex_lock(&g_tunnel.lock); - ws_close(rq->client, 1011); - pthread_mutex_unlock(&g_tunnel.lock); - free(h); - /** TODO: remove all subscriber of this ws connection */ - return task; - } - if (h->opcode == WS_BIN) - { - // we have data, now read the message, - // the message must be in bin - buffer = (uint8_t*) malloc(h->plen + 1); - if(buffer) + if (h->mask == 0) { + LOG("Data is not mask"); + // kill the child process + free(h); pthread_mutex_lock(&g_tunnel.lock); - rc = ws_read_data(rq->client,h, h->plen, buffer); + ws_close(rq->client, 1011); + bst_for_each(g_tunnel.channels, unsubscribe_notify, argv, 1); pthread_mutex_unlock(&g_tunnel.lock); - if(h->plen == 0) - { - offset = 0; - // verify begin magic - (void)memcpy(&long_value, buffer,sizeof(long_value)); - offset += sizeof(long_value); - if(long_value != MSG_MAGIC_BEGIN) - { - ERROR("Invalid begin magic number: %d, expected %d", long_value, MSG_MAGIC_BEGIN); - free(buffer); - goto reschedule_task; - } - // msgtype - (void) memcpy(&msg.header.type, buffer + offset, sizeof(msg.header.type)); - offset += sizeof(msg.header.type); - - // channel id - (void) memcpy(&msg.header.channel_id, buffer + offset, sizeof(msg.header.channel_id)); - offset += sizeof(msg.header.channel_id); - - // client id - (void) memcpy(&msg.header.client_id, buffer + offset, sizeof(msg.header.client_id)); - offset += sizeof(msg.header.client_id); - - // data size - (void) memcpy(&msg.header.size, buffer + offset, sizeof(msg.header.size)); - offset += sizeof(msg.header.size); - - // data - msg.data = buffer + offset; - offset += msg.header.size; - - // verify end magic - (void)memcpy(&long_value, buffer + offset ,sizeof(long_value)); - offset += sizeof(long_value); - if(long_value != MSG_MAGIC_END) - { - ERROR("Invalid end magic number: %d, expected %d", long_value, MSG_MAGIC_END); - free(buffer); - goto reschedule_task; - } - - // now we have the message - pthread_mutex_lock(&g_tunnel.lock); - process_client_message(&msg, rq->client); - pthread_mutex_unlock(&g_tunnel.lock); - } - free(buffer); + return task; } + if (h->opcode == WS_CLOSE) + { + LOG("Websocket: connection closed"); + pthread_mutex_lock(&g_tunnel.lock); + //ws_close(rq->client, 1011); + bst_for_each(g_tunnel.channels, unsubscribe_notify, argv, 1); + pthread_mutex_unlock(&g_tunnel.lock); + free(h); + return task; + } + if (h->opcode == WS_BIN) + { + // we have data, now read the message, + // the message must be in bin + buffer = (uint8_t*) malloc(h->plen + 1); + if(buffer) + { + pthread_mutex_lock(&g_tunnel.lock); + ws_read_data(rq->client,h, h->plen, buffer); + pthread_mutex_unlock(&g_tunnel.lock); + if(h->plen == 0) + { + offset = 0; + // verify begin magic + (void)memcpy(&long_value, buffer,sizeof(long_value)); + offset += sizeof(long_value); + if(long_value != MSG_MAGIC_BEGIN) + { + ERROR("Invalid begin magic number: %d, expected %d", long_value, MSG_MAGIC_BEGIN); + free(buffer); + goto reschedule_task; + } + // msgtype + (void) memcpy(&msg.header.type, buffer + offset, sizeof(msg.header.type)); + offset += sizeof(msg.header.type); + + // channel id + (void) memcpy(&msg.header.channel_id, buffer + offset, sizeof(msg.header.channel_id)); + offset += sizeof(msg.header.channel_id); + + // client id + (void) memcpy(&msg.header.client_id, buffer + offset, sizeof(msg.header.client_id)); + offset += sizeof(msg.header.client_id); + + // data size + (void) memcpy(&msg.header.size, buffer + offset, sizeof(msg.header.size)); + offset += sizeof(msg.header.size); + + // data + msg.data = buffer + offset; + offset += msg.header.size; + + // verify end magic + (void)memcpy(&long_value, buffer + offset ,sizeof(long_value)); + offset += sizeof(long_value); + if(long_value != MSG_MAGIC_END) + { + ERROR("Invalid end magic number: %d, expected %d", long_value, MSG_MAGIC_END); + free(buffer); + goto reschedule_task; + } + + // now we have the message + pthread_mutex_lock(&g_tunnel.lock); + process_client_message(&msg, rq->client); + pthread_mutex_unlock(&g_tunnel.lock); + } + free(buffer); + } + } + free(h); } - free(h); - } } } + else + { + return task; + } reschedule_task: task->handle = handle; task->type = HEAVY; task->access_time = time(NULL); + select(0, NULL, NULL, NULL, &timeout); return task; } \ No newline at end of file