diff --git a/src/context.c b/src/context.c index bf7b493..8e2b736 100644 --- a/src/context.c +++ b/src/context.c @@ -708,12 +708,10 @@ static int set_alpn_cb(lua_State *L) return 1; } -static unsigned int server_psk_cb( - SSL *ssl, - const char *identity, - unsigned char *psk, - unsigned int max_psk_len -) +/** + * Callback to select the PSK. + */ +static unsigned int server_psk_cb(SSL *ssl, const char *identity, unsigned char *psk, unsigned int max_psk_len) { int psk_len; const char *ret_psk; @@ -725,21 +723,18 @@ static unsigned int server_psk_cb( lua_pushlightuserdata(L, (void*)pctx->context); lua_gettable(L, -2); - lua_pushlstring(L, identity, strlen(identity)); - lua_pushnumber(L, max_psk_len, max_psk_len); + lua_pushstring(L, identity); + lua_pushinteger(L, max_psk_len); lua_call(L, 2, 1); - if (!lua_isstring(L, -1)) { + if (!lua_isstring(L, -1) || lua_objlen(L, -1) == 0) { lua_pop(L, 2); return 0; } ret_psk = luaL_checklstring(L, -1, &psk_len); - - if (psk_len > max_psk_len) psk_len = max_psk_len; - - memcpy(psk, ret_psk, psk_len); + memcpy(psk, ret_psk, (psk_len > max_psk_len) ? max_psk_len : psk_len); lua_pop(L, 2);