feat: tls-psk

This commit is contained in:
unknown 2023-02-16 09:52:18 +09:00
parent 480aef1626
commit 842380caf6
5 changed files with 230 additions and 3 deletions

View File

@ -45,6 +45,9 @@ Directories:
* oneshot * oneshot
A simple connection example. A simple connection example.
* psk
PSK(Pre Shared Key) support.
* sni * sni
Support to SNI (Server Name Indication). Support to SNI (Server Name Indication).

29
samples/psk/client.lua Normal file
View File

@ -0,0 +1,29 @@
--
-- Public domain
--
local socket = require("socket")
local ssl = require("ssl")
local params = {
mode = "client",
protocol = "tlsv1_2",
psk = function(hint, max_psk_len)
print("PSK Callback: hint=", hint, ", max_psk_len=", max_psk_len)
return "abcd", "1234"
end
}
local peer = socket.tcp()
peer:connect("127.0.0.1", 8888)
peer = assert( ssl.wrap(peer, params) )
assert(peer:dohandshake())
print("--- INFO ---")
local info = peer:info()
for k, v in pairs(info) do
print(k, v)
end
print("---")
peer:close()

42
samples/psk/server.lua Normal file
View File

@ -0,0 +1,42 @@
--
-- Public domain
--
local socket = require("socket")
local ssl = require("ssl")
local params = {
mode = "server",
protocol = "any",
options = "all",
psk = function(identity, max_psk_len)
print("PSK Callback: identity=", identity, ", max_psk_len=", max_psk_len)
if identity == "abcd" then
return "1234"
end
return nil
end
}
-- [[ SSL context
local ctx = assert(ssl.newcontext(params))
--]]
local server = socket.tcp()
server:setoption('reuseaddr', true)
assert( server:bind("127.0.0.1", 8888) )
server:listen()
local peer = server:accept()
peer = assert( ssl.wrap(peer, ctx) )
assert( peer:dohandshake() )
print("--- INFO ---")
local info = peer:info()
for k, v in pairs(info) do
print(k, v)
end
print("---")
peer:close()
server:close()

View File

@ -708,6 +708,124 @@ static int set_alpn_cb(lua_State *L)
return 1; return 1;
} }
static unsigned int server_psk_cb(
SSL *ssl,
const char *identity,
unsigned char *psk,
unsigned int max_psk_len
)
{
int psk_len;
const char *ret_psk;
SSL_CTX *ctx = SSL_get_SSL_CTX(ssl);
p_context pctx = (p_context)SSL_CTX_get_app_data(ctx);
lua_State *L = pctx->L;
luaL_getmetatable(L, "SSL:PSKServer:Registry");
lua_pushlightuserdata(L, (void*)pctx->context);
lua_gettable(L, -2);
lua_pushlstring(L, identity, strlen(identity));
lua_pushnumber(L, max_psk_len, max_psk_len);
lua_call(L, 2, 1);
if (!lua_isstring(L, -1)) {
lua_pop(L, 2);
return 0;
}
ret_psk = luaL_checklstring(L, -1, &psk_len);
if (psk_len > max_psk_len) psk_len = max_psk_len;
memcpy(psk, ret_psk, psk_len);
lua_pop(L, 2);
return psk_len;
}
/**
* Set a psk callback for server.
*/
static int set_server_psk_cb(lua_State *L) {
p_context ctx = checkctx(L, 1);
luaL_getmetatable(L, "SSL:PSKServer:Registry");
lua_pushlightuserdata(L, (void*)ctx->context);
lua_pushvalue(L, 2);
lua_settable(L, -3);
SSL_CTX_set_psk_server_callback(ctx->context, server_psk_cb);
lua_pushboolean(L, 1);
return 1;
}
static unsigned int client_psk_cb(
SSL *ssl, const char *hint,
char *identity, unsigned int max_identity_len,
unsigned char *psk, unsigned int max_psk_len
)
{
int identity_len;
const char *ret_identity;
int psk_len;
const char *ret_psk;
SSL_CTX *ctx = SSL_get_SSL_CTX(ssl);
p_context pctx = (p_context)SSL_CTX_get_app_data(ctx);
lua_State *L = pctx->L;
luaL_getmetatable(L, "SSL:PSKClient:Registry");
lua_pushlightuserdata(L, (void*)pctx->context);
lua_gettable(L, -2);
if (hint) {
lua_pushlstring(L, hint, strlen(hint));
} else {
lua_pushlstring(L, "", 0);
}
lua_pushnumber(L, max_psk_len, max_psk_len);
lua_call(L, 2, 2);
if (!lua_isstring(L, -1) || !lua_isstring(L, -2)) {
lua_pop(L, 2);
return 0;
}
ret_identity = luaL_checklstring(L, -2, &identity_len);
ret_psk = luaL_checklstring(L, -1, &psk_len);
if (identity_len > max_identity_len) identity_len = max_identity_len;
if (psk_len > max_psk_len) psk_len = max_psk_len;
memcpy(identity, ret_identity, identity_len);
memcpy(psk, ret_psk, psk_len);
lua_pop(L, 2);
return psk_len;
}
/**
* Set a psk callback for client.
*/
static int set_client_psk_cb(lua_State *L) {
p_context ctx = checkctx(L, 1);
luaL_getmetatable(L, "SSL:PSKClient:Registry");
lua_pushlightuserdata(L, (void*)ctx->context);
lua_pushvalue(L, 2);
lua_settable(L, -3);
SSL_CTX_set_psk_client_callback(ctx->context, client_psk_cb);
lua_pushboolean(L, 1);
return 1;
}
#if defined(LSEC_ENABLE_DANE) #if defined(LSEC_ENABLE_DANE)
/* /*
* DANE * DANE
@ -759,6 +877,8 @@ static luaL_Reg funcs[] = {
{"setdhparam", set_dhparam}, {"setdhparam", set_dhparam},
{"setverify", set_verify}, {"setverify", set_verify},
{"setoptions", set_options}, {"setoptions", set_options},
{"setserverpskcb", set_server_psk_cb},
{"setclientpskcb", set_client_psk_cb},
{"setmode", set_mode}, {"setmode", set_mode},
#if !defined(OPENSSL_NO_EC) #if !defined(OPENSSL_NO_EC)
{"setcurve", set_curve}, {"setcurve", set_curve},
@ -792,6 +912,14 @@ static int meth_destroy(lua_State *L)
lua_pushlightuserdata(L, (void*)ctx->context); lua_pushlightuserdata(L, (void*)ctx->context);
lua_pushnil(L); lua_pushnil(L);
lua_settable(L, -3); lua_settable(L, -3);
luaL_getmetatable(L, "SSL:PSKServer:Registry");
lua_pushlightuserdata(L, (void*)ctx->context);
lua_pushnil(L);
lua_settable(L, -3);
luaL_getmetatable(L, "SSL:PSKClient:Registry");
lua_pushlightuserdata(L, (void*)ctx->context);
lua_pushnil(L);
lua_settable(L, -3);
SSL_CTX_free(ctx->context); SSL_CTX_free(ctx->context);
ctx->context = NULL; ctx->context = NULL;
@ -934,9 +1062,11 @@ void *lsec_testudata (lua_State *L, int ud, const char *tname) {
*/ */
LSEC_API int luaopen_ssl_context(lua_State *L) LSEC_API int luaopen_ssl_context(lua_State *L)
{ {
luaL_newmetatable(L, "SSL:DH:Registry"); /* Keep all DH callbacks */ luaL_newmetatable(L, "SSL:DH:Registry"); /* Keep all DH callbacks */
luaL_newmetatable(L, "SSL:ALPN:Registry"); /* Keep all ALPN callbacks */ luaL_newmetatable(L, "SSL:ALPN:Registry"); /* Keep all ALPN callbacks */
luaL_newmetatable(L, "SSL:Verify:Registry"); /* Keep all verify flags */ luaL_newmetatable(L, "SSL:PSKServer:Registry"); /* Keep all PSK callbacks */
luaL_newmetatable(L, "SSL:PSKClient:Registry"); /* Keep all PSK callbacks */
luaL_newmetatable(L, "SSL:Verify:Registry"); /* Keep all verify flags */
luaL_newmetatable(L, "SSL:Context"); luaL_newmetatable(L, "SSL:Context");
setfuncs(L, meta); setfuncs(L, meta);

View File

@ -201,6 +201,29 @@ local function newcontext(cfg)
if not succ then return nil, msg end if not succ then return nil, msg end
end end
-- PSK
if cfg.psk then
if type(cfg.psk) == "function" then
local pskcb = cfg.psk
if cfg.mode == "client" then
succ, msg = context.setclientpskcb(ctx, function(hint, max_psk_len)
local identity, psk = pskcb(hint, max_psk_len)
return identity, psk
end)
if not succ then return nil, msg end
else
succ, msg = context.setserverpskcb(ctx, function(identity, max_psk_len)
local psk = pskcb(identity, max_psk_len)
return psk
end)
if not succ then return nil, msg end
end
else
return nil, "invalid PSK Callback parameter"
end
end
if config.capabilities.dane and cfg.dane then if config.capabilities.dane and cfg.dane then
if type(cfg.dane) == "table" then if type(cfg.dane) == "table" then
context.setdane(ctx, unpack(cfg.dane)) context.setdane(ctx, unpack(cfg.dane))