diff --git a/samples/README b/samples/README index 517eefe..bd983d4 100644 --- a/samples/README +++ b/samples/README @@ -45,6 +45,9 @@ Directories: * oneshot A simple connection example. +* psk + PSK(Pre Shared Key) support. + * sni Support to SNI (Server Name Indication). diff --git a/samples/psk/client.lua b/samples/psk/client.lua new file mode 100644 index 0000000..75eeb63 --- /dev/null +++ b/samples/psk/client.lua @@ -0,0 +1,36 @@ +-- +-- Public domain +-- +local socket = require("socket") +local ssl = require("ssl") + +-- @param hint (nil | string) +-- @param max_identity_len (number) +-- @param max_psk_len (number) +-- @return identity (string) +-- @return PSK (string) +local function pskcb(hint, max_identity_len, max_psk_len) + print(string.format("PSK Callback: hint=%q, max_identity_len=%d, max_psk_len=%d", hint, max_identity_len, max_psk_len)) + return "abcd", "1234" +end + +local params = { + mode = "client", + protocol = "tlsv1_2", + psk = pskcb, +} + +local peer = socket.tcp() +peer:connect("127.0.0.1", 8888) + +peer = assert( ssl.wrap(peer, params) ) +assert(peer:dohandshake()) + +print("--- INFO ---") +local info = peer:info() +for k, v in pairs(info) do + print(k, v) +end +print("---") + +peer:close() diff --git a/samples/psk/server.lua b/samples/psk/server.lua new file mode 100644 index 0000000..b5a958b --- /dev/null +++ b/samples/psk/server.lua @@ -0,0 +1,55 @@ +-- +-- Public domain +-- +local socket = require("socket") +local ssl = require("ssl") + +-- @param identity (string) +-- @param max_psk_len (number) +-- @return psk (string) +local function pskcb(identity, max_psk_len) + print(string.format("PSK Callback: identity=%q, max_psk_len=%d", identity, max_psk_len)) + if identity == "abcd" then + return "1234" + end + return nil +end + +local params = { + mode = "server", + protocol = "any", + options = "all", + +-- PSK with just a callback + psk = pskcb, + +-- PSK with identity hint +-- psk = { +-- hint = "hintpsksample", +-- callback = pskcb, +-- }, +} + + +-- [[ SSL context +local ctx = assert(ssl.newcontext(params)) +--]] + +local server = socket.tcp() +server:setoption('reuseaddr', true) +assert( server:bind("127.0.0.1", 8888) ) +server:listen() + +local peer = server:accept() +peer = assert( ssl.wrap(peer, ctx) ) +assert( peer:dohandshake() ) + +print("--- INFO ---") +local info = peer:info() +for k, v in pairs(info) do + print(k, v) +end +print("---") + +peer:close() +server:close() diff --git a/src/context.c b/src/context.c index 0bb4826..dcabbc0 100644 --- a/src/context.c +++ b/src/context.c @@ -708,6 +708,141 @@ static int set_alpn_cb(lua_State *L) return 1; } +/** + * Callback to select the PSK. + */ +static unsigned int server_psk_cb(SSL *ssl, const char *identity, unsigned char *psk, + unsigned int max_psk_len) +{ + size_t psk_len; + const char *ret_psk; + SSL_CTX *ctx = SSL_get_SSL_CTX(ssl); + p_context pctx = (p_context)SSL_CTX_get_app_data(ctx); + lua_State *L = pctx->L; + + luaL_getmetatable(L, "SSL:PSK:Registry"); + lua_pushlightuserdata(L, (void*)pctx->context); + lua_gettable(L, -2); + + lua_pushstring(L, identity); + lua_pushinteger(L, max_psk_len); + + lua_call(L, 2, 1); + + if (!lua_isstring(L, -1)) { + lua_pop(L, 2); + return 0; + } + + ret_psk = lua_tolstring(L, -1, &psk_len); + + if (psk_len == 0 || psk_len > max_psk_len) + psk_len = 0; + else + memcpy(psk, ret_psk, psk_len); + + lua_pop(L, 2); + + return psk_len; +} + +/** + * Set a PSK callback for server. + */ +static int set_server_psk_cb(lua_State *L) +{ + p_context ctx = checkctx(L, 1); + + luaL_getmetatable(L, "SSL:PSK:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_pushvalue(L, 2); + lua_settable(L, -3); + + SSL_CTX_set_psk_server_callback(ctx->context, server_psk_cb); + + lua_pushboolean(L, 1); + return 1; +} + +/* + * Set the PSK indentity hint. + */ +static int set_psk_identity_hint(lua_State *L) +{ + p_context ctx = checkctx(L, 1); + const char *hint = luaL_checkstring(L, 2); + int ret = SSL_CTX_use_psk_identity_hint(ctx->context, hint); + lua_pushboolean(L, ret); + return 1; +} + +/* + * Client callback to PSK. + */ +static unsigned int client_psk_cb(SSL *ssl, const char *hint, char *identity, + unsigned int max_identity_len, unsigned char *psk, unsigned int max_psk_len) +{ + size_t psk_len; + size_t identity_len; + const char *ret_psk; + const char *ret_identity; + SSL_CTX *ctx = SSL_get_SSL_CTX(ssl); + p_context pctx = (p_context)SSL_CTX_get_app_data(ctx); + lua_State *L = pctx->L; + + luaL_getmetatable(L, "SSL:PSK:Registry"); + lua_pushlightuserdata(L, (void*)pctx->context); + lua_gettable(L, -2); + + if (hint) + lua_pushstring(L, hint); + else + lua_pushnil(L); + + // Leave space to '\0' + lua_pushinteger(L, max_identity_len-1); + lua_pushinteger(L, max_psk_len); + + lua_call(L, 3, 2); + + if (!lua_isstring(L, -1) || !lua_isstring(L, -2)) { + lua_pop(L, 3); + return 0; + } + + ret_identity = lua_tolstring(L, -2, &identity_len); + ret_psk = lua_tolstring(L, -1, &psk_len); + + if (identity_len >= max_identity_len || psk_len > max_psk_len) + psk_len = 0; + else { + memcpy(identity, ret_identity, identity_len); + identity[identity_len] = 0; + memcpy(psk, ret_psk, psk_len); + } + + lua_pop(L, 3); + + return psk_len; +} + +/** + * Set a PSK callback for client. + */ +static int set_client_psk_cb(lua_State *L) { + p_context ctx = checkctx(L, 1); + + luaL_getmetatable(L, "SSL:PSK:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_pushvalue(L, 2); + lua_settable(L, -3); + + SSL_CTX_set_psk_client_callback(ctx->context, client_psk_cb); + + lua_pushboolean(L, 1); + return 1; +} + #if defined(LSEC_ENABLE_DANE) /* * DANE @@ -759,6 +894,9 @@ static luaL_Reg funcs[] = { {"setdhparam", set_dhparam}, {"setverify", set_verify}, {"setoptions", set_options}, + {"setpskhint", set_psk_identity_hint}, + {"setserverpskcb", set_server_psk_cb}, + {"setclientpskcb", set_client_psk_cb}, {"setmode", set_mode}, #if !defined(OPENSSL_NO_EC) {"setcurve", set_curve}, @@ -792,6 +930,10 @@ static int meth_destroy(lua_State *L) lua_pushlightuserdata(L, (void*)ctx->context); lua_pushnil(L); lua_settable(L, -3); + luaL_getmetatable(L, "SSL:PSK:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_pushnil(L); + lua_settable(L, -3); SSL_CTX_free(ctx->context); ctx->context = NULL; @@ -934,9 +1076,10 @@ void *lsec_testudata (lua_State *L, int ud, const char *tname) { */ LSEC_API int luaopen_ssl_context(lua_State *L) { - luaL_newmetatable(L, "SSL:DH:Registry"); /* Keep all DH callbacks */ - luaL_newmetatable(L, "SSL:ALPN:Registry"); /* Keep all ALPN callbacks */ - luaL_newmetatable(L, "SSL:Verify:Registry"); /* Keep all verify flags */ + luaL_newmetatable(L, "SSL:DH:Registry"); /* Keep all DH callbacks */ + luaL_newmetatable(L, "SSL:ALPN:Registry"); /* Keep all ALPN callbacks */ + luaL_newmetatable(L, "SSL:PSK:Registry"); /* Keep all PSK callbacks */ + luaL_newmetatable(L, "SSL:Verify:Registry"); /* Keep all verify flags */ luaL_newmetatable(L, "SSL:Context"); setfuncs(L, meta); diff --git a/src/ssl.lua b/src/ssl.lua index f49b172..b182d53 100644 --- a/src/ssl.lua +++ b/src/ssl.lua @@ -201,6 +201,33 @@ local function newcontext(cfg) if not succ then return nil, msg end end + -- PSK + if cfg.psk then + if cfg.mode == "client" then + if type(cfg.psk) ~= "function" then + return nil, "invalid PSK configuration" + end + succ = context.setclientpskcb(ctx, cfg.psk) + if not succ then return nil, msg end + elseif cfg.mode == "server" then + if type(cfg.psk) == "function" then + succ, msg = context.setserverpskcb(ctx, cfg.psk) + if not succ then return nil, msg end + elseif type(cfg.psk) == "table" then + if type(cfg.psk.hint) == "string" and type(cfg.psk.callback) == "function" then + succ, msg = context.setpskhint(ctx, cfg.psk.hint) + if not succ then return succ, msg end + succ = context.setserverpskcb(ctx, cfg.psk.callback) + if not succ then return succ, msg end + else + return nil, "invalid PSK configuration" + end + else + return nil, "invalid PSK configuration" + end + end + end + if config.capabilities.dane and cfg.dane then if type(cfg.dane) == "table" then context.setdane(ctx, unpack(cfg.dane))