Fix PSK client callback

This commit is contained in:
Bruno Silvestre 2023-02-16 10:28:34 -03:00
parent 9b09c93249
commit dd8ba1fc92

View File

@ -763,16 +763,16 @@ static int set_server_psk_cb(lua_State *L)
return 1; return 1;
} }
static unsigned int client_psk_cb( /*
SSL *ssl, const char *hint, * Client callback to PSK.
char *identity, unsigned int max_identity_len, */
unsigned char *psk, unsigned int max_psk_len static unsigned int client_psk_cb(SSL *ssl, const char *hint, char *identity, unsigned int max_identity_len,
) unsigned char *psk, unsigned int max_psk_len)
{ {
int identity_len; size_t psk_len;
const char *ret_identity; size_t identity_len;
int psk_len;
const char *ret_psk; const char *ret_psk;
const char *ret_identity;
SSL_CTX *ctx = SSL_get_SSL_CTX(ssl); SSL_CTX *ctx = SSL_get_SSL_CTX(ssl);
p_context pctx = (p_context)SSL_CTX_get_app_data(ctx); p_context pctx = (p_context)SSL_CTX_get_app_data(ctx);
lua_State *L = pctx->L; lua_State *L = pctx->L;
@ -781,30 +781,34 @@ static unsigned int client_psk_cb(
lua_pushlightuserdata(L, (void*)pctx->context); lua_pushlightuserdata(L, (void*)pctx->context);
lua_gettable(L, -2); lua_gettable(L, -2);
if (hint) { if (hint)
lua_pushlstring(L, hint, strlen(hint)); lua_pushstring(L, hint);
} else { else
lua_pushlstring(L, "", 0); lua_pushnil(L);
}
lua_pushnumber(L, max_psk_len, max_psk_len);
lua_call(L, 2, 2); // Leave space to '\0'
lua_pushinteger(L, max_identity_len-1);
lua_pushinteger(L, max_psk_len);
lua_call(L, 3, 2);
if (!lua_isstring(L, -1) || !lua_isstring(L, -2)) { if (!lua_isstring(L, -1) || !lua_isstring(L, -2)) {
lua_pop(L, 2); lua_pop(L, 2);
return 0; return 0;
} }
ret_identity = luaL_checklstring(L, -2, &identity_len); ret_identity = luaL_tolstring(L, -2, &identity_len);
ret_psk = luaL_checklstring(L, -1, &psk_len); ret_psk = luaL_tolstring(L, -1, &psk_len);
if (identity_len > max_identity_len) identity_len = max_identity_len;
if (psk_len > max_psk_len) psk_len = max_psk_len;
if (ret_identity >= max_identity_len || psk_len > max_psk_len)
psk_len = 0;
else {
memcpy(identity, ret_identity, identity_len); memcpy(identity, ret_identity, identity_len);
identity[ret_identity] = 0;
memcpy(psk, ret_psk, psk_len); memcpy(psk, ret_psk, psk_len);
}
lua_pop(L, 2); lua_pop(L, 3);
return psk_len; return psk_len;
} }