feat: tls-psk

This commit is contained in:
unknown 2023-02-16 09:52:18 +09:00
parent 480aef1626
commit 842380caf6
5 changed files with 230 additions and 3 deletions

View File

@ -45,6 +45,9 @@ Directories:
* oneshot
A simple connection example.
* psk
PSK(Pre Shared Key) support.
* sni
Support to SNI (Server Name Indication).

29
samples/psk/client.lua Normal file
View File

@ -0,0 +1,29 @@
--
-- Public domain
--
local socket = require("socket")
local ssl = require("ssl")
local params = {
mode = "client",
protocol = "tlsv1_2",
psk = function(hint, max_psk_len)
print("PSK Callback: hint=", hint, ", max_psk_len=", max_psk_len)
return "abcd", "1234"
end
}
local peer = socket.tcp()
peer:connect("127.0.0.1", 8888)
peer = assert( ssl.wrap(peer, params) )
assert(peer:dohandshake())
print("--- INFO ---")
local info = peer:info()
for k, v in pairs(info) do
print(k, v)
end
print("---")
peer:close()

42
samples/psk/server.lua Normal file
View File

@ -0,0 +1,42 @@
--
-- Public domain
--
local socket = require("socket")
local ssl = require("ssl")
local params = {
mode = "server",
protocol = "any",
options = "all",
psk = function(identity, max_psk_len)
print("PSK Callback: identity=", identity, ", max_psk_len=", max_psk_len)
if identity == "abcd" then
return "1234"
end
return nil
end
}
-- [[ SSL context
local ctx = assert(ssl.newcontext(params))
--]]
local server = socket.tcp()
server:setoption('reuseaddr', true)
assert( server:bind("127.0.0.1", 8888) )
server:listen()
local peer = server:accept()
peer = assert( ssl.wrap(peer, ctx) )
assert( peer:dohandshake() )
print("--- INFO ---")
local info = peer:info()
for k, v in pairs(info) do
print(k, v)
end
print("---")
peer:close()
server:close()

View File

@ -708,6 +708,124 @@ static int set_alpn_cb(lua_State *L)
return 1;
}
static unsigned int server_psk_cb(
SSL *ssl,
const char *identity,
unsigned char *psk,
unsigned int max_psk_len
)
{
int psk_len;
const char *ret_psk;
SSL_CTX *ctx = SSL_get_SSL_CTX(ssl);
p_context pctx = (p_context)SSL_CTX_get_app_data(ctx);
lua_State *L = pctx->L;
luaL_getmetatable(L, "SSL:PSKServer:Registry");
lua_pushlightuserdata(L, (void*)pctx->context);
lua_gettable(L, -2);
lua_pushlstring(L, identity, strlen(identity));
lua_pushnumber(L, max_psk_len, max_psk_len);
lua_call(L, 2, 1);
if (!lua_isstring(L, -1)) {
lua_pop(L, 2);
return 0;
}
ret_psk = luaL_checklstring(L, -1, &psk_len);
if (psk_len > max_psk_len) psk_len = max_psk_len;
memcpy(psk, ret_psk, psk_len);
lua_pop(L, 2);
return psk_len;
}
/**
* Set a psk callback for server.
*/
static int set_server_psk_cb(lua_State *L) {
p_context ctx = checkctx(L, 1);
luaL_getmetatable(L, "SSL:PSKServer:Registry");
lua_pushlightuserdata(L, (void*)ctx->context);
lua_pushvalue(L, 2);
lua_settable(L, -3);
SSL_CTX_set_psk_server_callback(ctx->context, server_psk_cb);
lua_pushboolean(L, 1);
return 1;
}
static unsigned int client_psk_cb(
SSL *ssl, const char *hint,
char *identity, unsigned int max_identity_len,
unsigned char *psk, unsigned int max_psk_len
)
{
int identity_len;
const char *ret_identity;
int psk_len;
const char *ret_psk;
SSL_CTX *ctx = SSL_get_SSL_CTX(ssl);
p_context pctx = (p_context)SSL_CTX_get_app_data(ctx);
lua_State *L = pctx->L;
luaL_getmetatable(L, "SSL:PSKClient:Registry");
lua_pushlightuserdata(L, (void*)pctx->context);
lua_gettable(L, -2);
if (hint) {
lua_pushlstring(L, hint, strlen(hint));
} else {
lua_pushlstring(L, "", 0);
}
lua_pushnumber(L, max_psk_len, max_psk_len);
lua_call(L, 2, 2);
if (!lua_isstring(L, -1) || !lua_isstring(L, -2)) {
lua_pop(L, 2);
return 0;
}
ret_identity = luaL_checklstring(L, -2, &identity_len);
ret_psk = luaL_checklstring(L, -1, &psk_len);
if (identity_len > max_identity_len) identity_len = max_identity_len;
if (psk_len > max_psk_len) psk_len = max_psk_len;
memcpy(identity, ret_identity, identity_len);
memcpy(psk, ret_psk, psk_len);
lua_pop(L, 2);
return psk_len;
}
/**
* Set a psk callback for client.
*/
static int set_client_psk_cb(lua_State *L) {
p_context ctx = checkctx(L, 1);
luaL_getmetatable(L, "SSL:PSKClient:Registry");
lua_pushlightuserdata(L, (void*)ctx->context);
lua_pushvalue(L, 2);
lua_settable(L, -3);
SSL_CTX_set_psk_client_callback(ctx->context, client_psk_cb);
lua_pushboolean(L, 1);
return 1;
}
#if defined(LSEC_ENABLE_DANE)
/*
* DANE
@ -759,6 +877,8 @@ static luaL_Reg funcs[] = {
{"setdhparam", set_dhparam},
{"setverify", set_verify},
{"setoptions", set_options},
{"setserverpskcb", set_server_psk_cb},
{"setclientpskcb", set_client_psk_cb},
{"setmode", set_mode},
#if !defined(OPENSSL_NO_EC)
{"setcurve", set_curve},
@ -792,6 +912,14 @@ static int meth_destroy(lua_State *L)
lua_pushlightuserdata(L, (void*)ctx->context);
lua_pushnil(L);
lua_settable(L, -3);
luaL_getmetatable(L, "SSL:PSKServer:Registry");
lua_pushlightuserdata(L, (void*)ctx->context);
lua_pushnil(L);
lua_settable(L, -3);
luaL_getmetatable(L, "SSL:PSKClient:Registry");
lua_pushlightuserdata(L, (void*)ctx->context);
lua_pushnil(L);
lua_settable(L, -3);
SSL_CTX_free(ctx->context);
ctx->context = NULL;
@ -934,9 +1062,11 @@ void *lsec_testudata (lua_State *L, int ud, const char *tname) {
*/
LSEC_API int luaopen_ssl_context(lua_State *L)
{
luaL_newmetatable(L, "SSL:DH:Registry"); /* Keep all DH callbacks */
luaL_newmetatable(L, "SSL:ALPN:Registry"); /* Keep all ALPN callbacks */
luaL_newmetatable(L, "SSL:Verify:Registry"); /* Keep all verify flags */
luaL_newmetatable(L, "SSL:DH:Registry"); /* Keep all DH callbacks */
luaL_newmetatable(L, "SSL:ALPN:Registry"); /* Keep all ALPN callbacks */
luaL_newmetatable(L, "SSL:PSKServer:Registry"); /* Keep all PSK callbacks */
luaL_newmetatable(L, "SSL:PSKClient:Registry"); /* Keep all PSK callbacks */
luaL_newmetatable(L, "SSL:Verify:Registry"); /* Keep all verify flags */
luaL_newmetatable(L, "SSL:Context");
setfuncs(L, meta);

View File

@ -201,6 +201,29 @@ local function newcontext(cfg)
if not succ then return nil, msg end
end
-- PSK
if cfg.psk then
if type(cfg.psk) == "function" then
local pskcb = cfg.psk
if cfg.mode == "client" then
succ, msg = context.setclientpskcb(ctx, function(hint, max_psk_len)
local identity, psk = pskcb(hint, max_psk_len)
return identity, psk
end)
if not succ then return nil, msg end
else
succ, msg = context.setserverpskcb(ctx, function(identity, max_psk_len)
local psk = pskcb(identity, max_psk_len)
return psk
end)
if not succ then return nil, msg end
end
else
return nil, "invalid PSK Callback parameter"
end
end
if config.capabilities.dane and cfg.dane then
if type(cfg.dane) == "table" then
context.setdane(ctx, unpack(cfg.dane))