From dd8ba1fc9248625c45ca69b019ac9b4fa6f97224 Mon Sep 17 00:00:00 2001 From: Bruno Silvestre Date: Thu, 16 Feb 2023 10:28:34 -0300 Subject: [PATCH] Fix PSK client callback --- src/context.c | 50 +++++++++++++++++++++++++++----------------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/context.c b/src/context.c index b4c761b..8bbcd92 100644 --- a/src/context.c +++ b/src/context.c @@ -763,16 +763,16 @@ static int set_server_psk_cb(lua_State *L) return 1; } -static unsigned int client_psk_cb( - SSL *ssl, const char *hint, - char *identity, unsigned int max_identity_len, - unsigned char *psk, unsigned int max_psk_len -) +/* + * Client callback to PSK. + */ +static unsigned int client_psk_cb(SSL *ssl, const char *hint, char *identity, unsigned int max_identity_len, + unsigned char *psk, unsigned int max_psk_len) { - int identity_len; - const char *ret_identity; - int psk_len; + size_t psk_len; + size_t identity_len; const char *ret_psk; + const char *ret_identity; SSL_CTX *ctx = SSL_get_SSL_CTX(ssl); p_context pctx = (p_context)SSL_CTX_get_app_data(ctx); lua_State *L = pctx->L; @@ -781,30 +781,34 @@ static unsigned int client_psk_cb( lua_pushlightuserdata(L, (void*)pctx->context); lua_gettable(L, -2); - if (hint) { - lua_pushlstring(L, hint, strlen(hint)); - } else { - lua_pushlstring(L, "", 0); - } - lua_pushnumber(L, max_psk_len, max_psk_len); + if (hint) + lua_pushstring(L, hint); + else + lua_pushnil(L); - lua_call(L, 2, 2); + // Leave space to '\0' + lua_pushinteger(L, max_identity_len-1); + lua_pushinteger(L, max_psk_len); + + lua_call(L, 3, 2); if (!lua_isstring(L, -1) || !lua_isstring(L, -2)) { lua_pop(L, 2); return 0; } - ret_identity = luaL_checklstring(L, -2, &identity_len); - ret_psk = luaL_checklstring(L, -1, &psk_len); + ret_identity = luaL_tolstring(L, -2, &identity_len); + ret_psk = luaL_tolstring(L, -1, &psk_len); - if (identity_len > max_identity_len) identity_len = max_identity_len; - if (psk_len > max_psk_len) psk_len = max_psk_len; + if (ret_identity >= max_identity_len || psk_len > max_psk_len) + psk_len = 0; + else { + memcpy(identity, ret_identity, identity_len); + identity[ret_identity] = 0; + memcpy(psk, ret_psk, psk_len); + } - memcpy(identity, ret_identity, identity_len); - memcpy(psk, ret_psk, psk_len); - - lua_pop(L, 2); + lua_pop(L, 3); return psk_len; }