From 842380caf6f51c3420e080c90cb547e27cd76bd3 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 16 Feb 2023 09:52:18 +0900 Subject: [PATCH] feat: tls-psk --- samples/README | 3 + samples/psk/client.lua | 29 +++++++++ samples/psk/server.lua | 42 +++++++++++++ src/context.c | 136 ++++++++++++++++++++++++++++++++++++++++- src/ssl.lua | 23 +++++++ 5 files changed, 230 insertions(+), 3 deletions(-) create mode 100644 samples/psk/client.lua create mode 100644 samples/psk/server.lua diff --git a/samples/README b/samples/README index 517eefe..bd983d4 100644 --- a/samples/README +++ b/samples/README @@ -45,6 +45,9 @@ Directories: * oneshot A simple connection example. +* psk + PSK(Pre Shared Key) support. + * sni Support to SNI (Server Name Indication). diff --git a/samples/psk/client.lua b/samples/psk/client.lua new file mode 100644 index 0000000..b496a5e --- /dev/null +++ b/samples/psk/client.lua @@ -0,0 +1,29 @@ +-- +-- Public domain +-- +local socket = require("socket") +local ssl = require("ssl") + +local params = { + mode = "client", + protocol = "tlsv1_2", + psk = function(hint, max_psk_len) + print("PSK Callback: hint=", hint, ", max_psk_len=", max_psk_len) + return "abcd", "1234" + end +} + +local peer = socket.tcp() +peer:connect("127.0.0.1", 8888) + +peer = assert( ssl.wrap(peer, params) ) +assert(peer:dohandshake()) + +print("--- INFO ---") +local info = peer:info() +for k, v in pairs(info) do + print(k, v) +end +print("---") + +peer:close() diff --git a/samples/psk/server.lua b/samples/psk/server.lua new file mode 100644 index 0000000..aa9aff8 --- /dev/null +++ b/samples/psk/server.lua @@ -0,0 +1,42 @@ +-- +-- Public domain +-- +local socket = require("socket") +local ssl = require("ssl") + +local params = { + mode = "server", + protocol = "any", + options = "all", + psk = function(identity, max_psk_len) + print("PSK Callback: identity=", identity, ", max_psk_len=", max_psk_len) + if identity == "abcd" then + return "1234" + end + return nil + end +} + + +-- [[ SSL context +local ctx = assert(ssl.newcontext(params)) +--]] + +local server = socket.tcp() +server:setoption('reuseaddr', true) +assert( server:bind("127.0.0.1", 8888) ) +server:listen() + +local peer = server:accept() +peer = assert( ssl.wrap(peer, ctx) ) +assert( peer:dohandshake() ) + +print("--- INFO ---") +local info = peer:info() +for k, v in pairs(info) do + print(k, v) +end +print("---") + +peer:close() +server:close() diff --git a/src/context.c b/src/context.c index 0bb4826..8937fbb 100644 --- a/src/context.c +++ b/src/context.c @@ -708,6 +708,124 @@ static int set_alpn_cb(lua_State *L) return 1; } +static unsigned int server_psk_cb( + SSL *ssl, + const char *identity, + unsigned char *psk, + unsigned int max_psk_len +) +{ + int psk_len; + const char *ret_psk; + SSL_CTX *ctx = SSL_get_SSL_CTX(ssl); + p_context pctx = (p_context)SSL_CTX_get_app_data(ctx); + lua_State *L = pctx->L; + + luaL_getmetatable(L, "SSL:PSKServer:Registry"); + lua_pushlightuserdata(L, (void*)pctx->context); + lua_gettable(L, -2); + + lua_pushlstring(L, identity, strlen(identity)); + lua_pushnumber(L, max_psk_len, max_psk_len); + + lua_call(L, 2, 1); + + if (!lua_isstring(L, -1)) { + lua_pop(L, 2); + return 0; + } + + ret_psk = luaL_checklstring(L, -1, &psk_len); + + if (psk_len > max_psk_len) psk_len = max_psk_len; + + memcpy(psk, ret_psk, psk_len); + + lua_pop(L, 2); + + return psk_len; +} + +/** + * Set a psk callback for server. + */ +static int set_server_psk_cb(lua_State *L) { + p_context ctx = checkctx(L, 1); + + luaL_getmetatable(L, "SSL:PSKServer:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_pushvalue(L, 2); + lua_settable(L, -3); + + SSL_CTX_set_psk_server_callback(ctx->context, server_psk_cb); + + lua_pushboolean(L, 1); + return 1; +} + +static unsigned int client_psk_cb( + SSL *ssl, const char *hint, + char *identity, unsigned int max_identity_len, + unsigned char *psk, unsigned int max_psk_len +) +{ + int identity_len; + const char *ret_identity; + int psk_len; + const char *ret_psk; + SSL_CTX *ctx = SSL_get_SSL_CTX(ssl); + p_context pctx = (p_context)SSL_CTX_get_app_data(ctx); + lua_State *L = pctx->L; + + luaL_getmetatable(L, "SSL:PSKClient:Registry"); + lua_pushlightuserdata(L, (void*)pctx->context); + lua_gettable(L, -2); + + if (hint) { + lua_pushlstring(L, hint, strlen(hint)); + } else { + lua_pushlstring(L, "", 0); + } + lua_pushnumber(L, max_psk_len, max_psk_len); + + lua_call(L, 2, 2); + + if (!lua_isstring(L, -1) || !lua_isstring(L, -2)) { + lua_pop(L, 2); + return 0; + } + + ret_identity = luaL_checklstring(L, -2, &identity_len); + ret_psk = luaL_checklstring(L, -1, &psk_len); + + if (identity_len > max_identity_len) identity_len = max_identity_len; + if (psk_len > max_psk_len) psk_len = max_psk_len; + + memcpy(identity, ret_identity, identity_len); + memcpy(psk, ret_psk, psk_len); + + lua_pop(L, 2); + + return psk_len; +} + +/** + * Set a psk callback for client. + */ +static int set_client_psk_cb(lua_State *L) { + p_context ctx = checkctx(L, 1); + + luaL_getmetatable(L, "SSL:PSKClient:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_pushvalue(L, 2); + lua_settable(L, -3); + + SSL_CTX_set_psk_client_callback(ctx->context, client_psk_cb); + + lua_pushboolean(L, 1); + return 1; +} + #if defined(LSEC_ENABLE_DANE) /* * DANE @@ -759,6 +877,8 @@ static luaL_Reg funcs[] = { {"setdhparam", set_dhparam}, {"setverify", set_verify}, {"setoptions", set_options}, + {"setserverpskcb", set_server_psk_cb}, + {"setclientpskcb", set_client_psk_cb}, {"setmode", set_mode}, #if !defined(OPENSSL_NO_EC) {"setcurve", set_curve}, @@ -792,6 +912,14 @@ static int meth_destroy(lua_State *L) lua_pushlightuserdata(L, (void*)ctx->context); lua_pushnil(L); lua_settable(L, -3); + luaL_getmetatable(L, "SSL:PSKServer:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_pushnil(L); + lua_settable(L, -3); + luaL_getmetatable(L, "SSL:PSKClient:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_pushnil(L); + lua_settable(L, -3); SSL_CTX_free(ctx->context); ctx->context = NULL; @@ -934,9 +1062,11 @@ void *lsec_testudata (lua_State *L, int ud, const char *tname) { */ LSEC_API int luaopen_ssl_context(lua_State *L) { - luaL_newmetatable(L, "SSL:DH:Registry"); /* Keep all DH callbacks */ - luaL_newmetatable(L, "SSL:ALPN:Registry"); /* Keep all ALPN callbacks */ - luaL_newmetatable(L, "SSL:Verify:Registry"); /* Keep all verify flags */ + luaL_newmetatable(L, "SSL:DH:Registry"); /* Keep all DH callbacks */ + luaL_newmetatable(L, "SSL:ALPN:Registry"); /* Keep all ALPN callbacks */ + luaL_newmetatable(L, "SSL:PSKServer:Registry"); /* Keep all PSK callbacks */ + luaL_newmetatable(L, "SSL:PSKClient:Registry"); /* Keep all PSK callbacks */ + luaL_newmetatable(L, "SSL:Verify:Registry"); /* Keep all verify flags */ luaL_newmetatable(L, "SSL:Context"); setfuncs(L, meta); diff --git a/src/ssl.lua b/src/ssl.lua index f49b172..adda272 100644 --- a/src/ssl.lua +++ b/src/ssl.lua @@ -201,6 +201,29 @@ local function newcontext(cfg) if not succ then return nil, msg end end + -- PSK + if cfg.psk then + if type(cfg.psk) == "function" then + local pskcb = cfg.psk + + if cfg.mode == "client" then + succ, msg = context.setclientpskcb(ctx, function(hint, max_psk_len) + local identity, psk = pskcb(hint, max_psk_len) + return identity, psk + end) + if not succ then return nil, msg end + else + succ, msg = context.setserverpskcb(ctx, function(identity, max_psk_len) + local psk = pskcb(identity, max_psk_len) + return psk + end) + if not succ then return nil, msg end + end + else + return nil, "invalid PSK Callback parameter" + end + end + if config.capabilities.dane and cfg.dane then if type(cfg.dane) == "table" then context.setdane(ctx, unpack(cfg.dane))