diff --git a/samples/alpn/client.lua b/samples/alpn/client.lua new file mode 100644 index 0000000..dc55452 --- /dev/null +++ b/samples/alpn/client.lua @@ -0,0 +1,27 @@ +-- +-- Public domain +-- +local socket = require("socket") +local ssl = require("ssl") + +local params = { + mode = "client", + protocol = "tlsv1_2", + key = "../certs/clientAkey.pem", + certificate = "../certs/clientA.pem", + cafile = "../certs/rootA.pem", + verify = {"peer", "fail_if_no_peer_cert"}, + options = "all", + --alpn = {"foo","bar","baz"} + alpn = "foo" +} + +local peer = socket.tcp() +peer:connect("127.0.0.1", 8888) + +peer = assert( ssl.wrap(peer, params) ) +assert(peer:dohandshake()) + +print("ALPN", peer:getalpn()) + +peer:close() diff --git a/samples/alpn/server.lua b/samples/alpn/server.lua new file mode 100644 index 0000000..08992e5 --- /dev/null +++ b/samples/alpn/server.lua @@ -0,0 +1,77 @@ +-- +-- Public domain +-- +local socket = require("socket") +local ssl = require("ssl") + +-- +-- Callback that selects one protocol from client's list. +-- +local function alpncb01(protocols) + print("--- ALPN protocols from client") + for k, v in ipairs(protocols) do + print(k, v) + end + print("--- Selecting:", protocols[1]) + return protocols[1] +end + +-- +-- Callback that returns a fixed list, ignoring the client's list. +-- +local function alpncb02(protocols) + print("--- ALPN protocols from client") + for k, v in ipairs(protocols) do + print(k, v) + end + print("--- Returning a fixed list") + return {"bar", "foo"} +end + +-- +-- Callback that generates a list as it whishes. +-- +local function alpncb03(protocols) + local resp = {} + print("--- ALPN protocols from client") + for k, v in ipairs(protocols) do + print(k, v) + if k%2 ~= 0 then resp[#resp+1] = v end + end + print("--- Returning an odd list") + return resp +end + + +local params = { + mode = "server", + protocol = "any", + key = "../certs/serverAkey.pem", + certificate = "../certs/serverA.pem", + cafile = "../certs/rootA.pem", + verify = {"peer", "fail_if_no_peer_cert"}, + options = "all", + --alpn = alpncb01, + --alpn = alpncb02, + --alpn = alpncb03, + alpn = {"bar", "baz", "foo"}, +} + + +-- [[ SSL context +local ctx = assert(ssl.newcontext(params)) +--]] + +local server = socket.tcp() +server:setoption('reuseaddr', true) +assert( server:bind("127.0.0.1", 8888) ) +server:listen() + +local peer = server:accept() +peer = assert( ssl.wrap(peer, ctx) ) +assert( peer:dohandshake() ) + +print("ALPN", peer:getalpn()) + +peer:close() +server:close() diff --git a/src/config.c b/src/config.c index 6939fca..6b68f46 100644 --- a/src/config.c +++ b/src/config.c @@ -65,6 +65,11 @@ LSEC_API int luaopen_ssl_config(lua_State *L) lua_pushstring(L, "capabilities"); lua_newtable(L); + // ALPN + lua_pushstring(L, "alpn"); + lua_pushboolean(L, 1); + lua_rawset(L, -3); + #ifndef OPENSSL_NO_EC #if defined(SSL_CTRL_SET_ECDH_AUTO) || defined(SSL_CTRL_SET_CURVES_LIST) || defined(SSL_CTX_set1_curves_list) lua_pushstring(L, "curves_list"); diff --git a/src/context.c b/src/context.c index 3186f34..2033004 100644 --- a/src/context.c +++ b/src/context.c @@ -609,6 +609,91 @@ static int set_curves_list(lua_State *L) } #endif +/** + * Set the protocols a client should send for ALPN. + */ +static int set_alpn(lua_State *L) +{ + long ret; + size_t len; + p_context ctx = checkctx(L, 1); + const char *str = luaL_checklstring(L, 2, &len); + + ret = SSL_CTX_set_alpn_protos(ctx->context, (const unsigned char*)str, len); + if (ret) { + lua_pushboolean(L, 0); + lua_pushfstring(L, "error setting ALPN (%s)", ERR_reason_error_string(ERR_get_error())); + return 2; + } + lua_pushboolean(L, 1); + return 1; +} + +/** + * This standard callback calls the server's callback in Lua sapce. + * The server has to return a list in wire-format strings. + * This function uses a helper function to match server and client lists. + */ +static int alpn_cb(SSL *s, const unsigned char **out, unsigned char *outlen, + const unsigned char *in, unsigned int inlen, void *arg) +{ + int ret; + size_t server_len; + const char *server; + p_context ctx = (p_context)arg; + lua_State *L = ctx->L; + + luaL_getmetatable(L, "SSL:ALPN:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_gettable(L, -2); + + lua_pushlstring(L, (const char*)in, inlen); + + lua_call(L, 1, 1); + + if (!lua_isstring(L, -1)) { + lua_pop(L, 2); + return SSL_TLSEXT_ERR_NOACK; + } + + // Protocol list from server in wire-format string + server = luaL_checklstring(L, -1, &server_len); + ret = SSL_select_next_proto((unsigned char**)out, outlen, (const unsigned char*)server, + server_len, in, inlen); + if (ret != OPENSSL_NPN_NEGOTIATED) { + lua_pop(L, 2); + return SSL_TLSEXT_ERR_NOACK; + } + + // Copy the result because lua_pop() can collect the pointer + ctx->alpn = malloc(*outlen); + memcpy(ctx->alpn, (void*)*out, *outlen); + *out = (const unsigned char*)ctx->alpn; + + lua_pop(L, 2); + + return SSL_TLSEXT_ERR_OK; +} + +/** + * Set a callback a server can use to select the next protocol with ALPN. + */ +static int set_alpn_cb(lua_State *L) +{ + p_context ctx = checkctx(L, 1); + + luaL_getmetatable(L, "SSL:ALPN:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_pushvalue(L, 2); + lua_settable(L, -3); + + SSL_CTX_set_alpn_select_cb(ctx->context, alpn_cb, ctx); + + lua_pushboolean(L, 1); + return 1; +} + + /** * Package functions */ @@ -618,6 +703,8 @@ static luaL_Reg funcs[] = { {"loadcert", load_cert}, {"loadkey", load_key}, {"checkkey", check_key}, + {"setalpn", set_alpn}, + {"setalpncb", set_alpn_cb}, {"setcipher", set_cipher}, {"setdepth", set_depth}, {"setdhparam", set_dhparam}, @@ -654,6 +741,10 @@ static int meth_destroy(lua_State *L) lua_pushlightuserdata(L, (void*)ctx->context); lua_pushnil(L); lua_settable(L, -3); + luaL_getmetatable(L, "SSL:ALPN:Registry"); + lua_pushlightuserdata(L, (void*)ctx->context); + lua_pushnil(L); + lua_settable(L, -3); SSL_CTX_free(ctx->context); ctx->context = NULL; @@ -795,8 +886,9 @@ void *lsec_testudata (lua_State *L, int ud, const char *tname) { */ LSEC_API int luaopen_ssl_context(lua_State *L) { - luaL_newmetatable(L, "SSL:DH:Registry"); /* Keep all DH callbacks */ - luaL_newmetatable(L, "SSL:Verify:Registry"); /* Keep all verify flags */ + luaL_newmetatable(L, "SSL:DH:Registry"); /* Keep all DH callbacks */ + luaL_newmetatable(L, "SSL:ALPN:Registry"); /* Keep all ALPN callbacks */ + luaL_newmetatable(L, "SSL:Verify:Registry"); /* Keep all verify flags */ luaL_newmetatable(L, "SSL:Context"); setfuncs(L, meta); diff --git a/src/context.h b/src/context.h index a971550..1ffe3ba 100644 --- a/src/context.h +++ b/src/context.h @@ -24,6 +24,7 @@ typedef struct t_context_ { SSL_CTX *context; lua_State *L; DH *dh_param; + void *alpn; int mode; } t_context; typedef t_context* p_context; diff --git a/src/ssl.c b/src/ssl.c index 45d143d..4dd4686 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -389,6 +389,10 @@ static int meth_handshake(lua_State *L) DH_free(ctx->dh_param); ctx->dh_param = NULL; } + if (ctx->alpn) { + free(ctx->alpn); + ctx->alpn = NULL; + } if (err == IO_DONE) { lua_pushboolean(L, 1); return 1; @@ -799,6 +803,19 @@ static int meth_getsniname(lua_State *L) return 1; } +static int meth_getalpn(lua_State *L) +{ + unsigned len; + const unsigned char *data; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + SSL_get0_alpn_selected(ssl->ssl, &data, &len); + if (data == NULL && len == 0) + lua_pushnil(L); + else + lua_pushlstring(L, (const char*)data, len); + return 1; +} + static int meth_copyright(lua_State *L) { lua_pushstring(L, "LuaSec 0.7 - Copyright (C) 2006-2018 Bruno Silvestre, UFG" @@ -816,6 +833,7 @@ static int meth_copyright(lua_State *L) */ static luaL_Reg methods[] = { {"close", meth_close}, + {"getalpn", meth_getalpn}, {"getfd", meth_getfd}, {"getfinished", meth_getfinished}, {"getpeercertificate", meth_getpeercertificate}, diff --git a/src/ssl.lua b/src/ssl.lua index 3bd236b..d5fbd59 100644 --- a/src/ssl.lua +++ b/src/ssl.lua @@ -30,6 +30,39 @@ local function optexec(func, param, ctx) return true end +-- +-- Convert an array of strings to wire-format +-- +local function array2wireformat(array) + local str = "" + for k, v in ipairs(array) do + if type(v) ~= "string" then return nil end + local len = #v + if len == 0 then + return nil, "invalid ALPN name (empty string)" + elseif len > 255 then + return nil, "invalid ALPN name (length > 255)" + end + str = str .. string.char(len) .. v + end + if str == "" then return nil, "invalid ALPN list (empty)" end + return str +end + +-- +-- Convert wire-string format to array +-- +local function wireformat2array(str) + local i = 1 + local array = {} + while i < #str do + local len = str:byte(i) + array[#array + 1] = str:sub(i + 1, i + len) + i = i + len + 1 + end + return array +end + -- -- -- @@ -113,6 +146,48 @@ local function newcontext(cfg) if not succ then return nil, msg end end + -- ALPN + if cfg.mode == "server" and cfg.alpn then + if type(cfg.alpn) == "function" then + local alpncb = cfg.alpn + -- This callback function has to return one value only + succ, msg = context.setalpncb(ctx, function(str) + local protocols = alpncb(wireformat2array(str)) + if type(protocols) == "string" then + protocols = { protocols } + elseif type(protocols) ~= "table" then + return nil + end + return (array2wireformat(protocols)) -- use "()" to drop error message + end) + if not succ then return nil, msg end + elseif type(cfg.alpn) == "table" then + local protocols = cfg.alpn + -- check if array is valid before use it + succ, msg = array2wireformat(protocols) + if not succ then return nil, msg end + -- This callback function has to return one value only + succ, msg = context.setalpncb(ctx, function() + return (array2wireformat(protocols)) -- use "()" to drop error message + end) + if not succ then return nil, msg end + else + return nil, "invalid ALPN parameter" + end + elseif cfg.mode == "client" and cfg.alpn then + local alpn + if type(cfg.alpn) == "string" then + alpn, msg = array2wireformat({ cfg.alpn }) + elseif type(cfg.alpn) == "table" then + alpn, msg = array2wireformat(cfg.alpn) + else + return nil, "invalid ALPN parameter" + end + if not alpn then return nil, msg end + succ, msg = context.setalpn(ctx, alpn) + if not succ then return nil, msg end + end + return ctx end