mirror of
https://github.com/brunoos/luasec.git
synced 2024-12-27 12:58:21 +01:00
Add ALPN support based on PR #64 from xnyhps
This commit is contained in:
parent
fdb2fa5f59
commit
dea60edf4f
27
samples/alpn/client.lua
Normal file
27
samples/alpn/client.lua
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
--
|
||||||
|
-- Public domain
|
||||||
|
--
|
||||||
|
local socket = require("socket")
|
||||||
|
local ssl = require("ssl")
|
||||||
|
|
||||||
|
local params = {
|
||||||
|
mode = "client",
|
||||||
|
protocol = "tlsv1_2",
|
||||||
|
key = "../certs/clientAkey.pem",
|
||||||
|
certificate = "../certs/clientA.pem",
|
||||||
|
cafile = "../certs/rootA.pem",
|
||||||
|
verify = {"peer", "fail_if_no_peer_cert"},
|
||||||
|
options = "all",
|
||||||
|
--alpn = {"foo","bar","baz"}
|
||||||
|
alpn = "foo"
|
||||||
|
}
|
||||||
|
|
||||||
|
local peer = socket.tcp()
|
||||||
|
peer:connect("127.0.0.1", 8888)
|
||||||
|
|
||||||
|
peer = assert( ssl.wrap(peer, params) )
|
||||||
|
assert(peer:dohandshake())
|
||||||
|
|
||||||
|
print("ALPN", peer:getalpn())
|
||||||
|
|
||||||
|
peer:close()
|
77
samples/alpn/server.lua
Normal file
77
samples/alpn/server.lua
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
--
|
||||||
|
-- Public domain
|
||||||
|
--
|
||||||
|
local socket = require("socket")
|
||||||
|
local ssl = require("ssl")
|
||||||
|
|
||||||
|
--
|
||||||
|
-- Callback that selects one protocol from client's list.
|
||||||
|
--
|
||||||
|
local function alpncb01(protocols)
|
||||||
|
print("--- ALPN protocols from client")
|
||||||
|
for k, v in ipairs(protocols) do
|
||||||
|
print(k, v)
|
||||||
|
end
|
||||||
|
print("--- Selecting:", protocols[1])
|
||||||
|
return protocols[1]
|
||||||
|
end
|
||||||
|
|
||||||
|
--
|
||||||
|
-- Callback that returns a fixed list, ignoring the client's list.
|
||||||
|
--
|
||||||
|
local function alpncb02(protocols)
|
||||||
|
print("--- ALPN protocols from client")
|
||||||
|
for k, v in ipairs(protocols) do
|
||||||
|
print(k, v)
|
||||||
|
end
|
||||||
|
print("--- Returning a fixed list")
|
||||||
|
return {"bar", "foo"}
|
||||||
|
end
|
||||||
|
|
||||||
|
--
|
||||||
|
-- Callback that generates a list as it whishes.
|
||||||
|
--
|
||||||
|
local function alpncb03(protocols)
|
||||||
|
local resp = {}
|
||||||
|
print("--- ALPN protocols from client")
|
||||||
|
for k, v in ipairs(protocols) do
|
||||||
|
print(k, v)
|
||||||
|
if k%2 ~= 0 then resp[#resp+1] = v end
|
||||||
|
end
|
||||||
|
print("--- Returning an odd list")
|
||||||
|
return resp
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
local params = {
|
||||||
|
mode = "server",
|
||||||
|
protocol = "any",
|
||||||
|
key = "../certs/serverAkey.pem",
|
||||||
|
certificate = "../certs/serverA.pem",
|
||||||
|
cafile = "../certs/rootA.pem",
|
||||||
|
verify = {"peer", "fail_if_no_peer_cert"},
|
||||||
|
options = "all",
|
||||||
|
--alpn = alpncb01,
|
||||||
|
--alpn = alpncb02,
|
||||||
|
--alpn = alpncb03,
|
||||||
|
alpn = {"bar", "baz", "foo"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
-- [[ SSL context
|
||||||
|
local ctx = assert(ssl.newcontext(params))
|
||||||
|
--]]
|
||||||
|
|
||||||
|
local server = socket.tcp()
|
||||||
|
server:setoption('reuseaddr', true)
|
||||||
|
assert( server:bind("127.0.0.1", 8888) )
|
||||||
|
server:listen()
|
||||||
|
|
||||||
|
local peer = server:accept()
|
||||||
|
peer = assert( ssl.wrap(peer, ctx) )
|
||||||
|
assert( peer:dohandshake() )
|
||||||
|
|
||||||
|
print("ALPN", peer:getalpn())
|
||||||
|
|
||||||
|
peer:close()
|
||||||
|
server:close()
|
@ -65,6 +65,11 @@ LSEC_API int luaopen_ssl_config(lua_State *L)
|
|||||||
lua_pushstring(L, "capabilities");
|
lua_pushstring(L, "capabilities");
|
||||||
lua_newtable(L);
|
lua_newtable(L);
|
||||||
|
|
||||||
|
// ALPN
|
||||||
|
lua_pushstring(L, "alpn");
|
||||||
|
lua_pushboolean(L, 1);
|
||||||
|
lua_rawset(L, -3);
|
||||||
|
|
||||||
#ifndef OPENSSL_NO_EC
|
#ifndef OPENSSL_NO_EC
|
||||||
#if defined(SSL_CTRL_SET_ECDH_AUTO) || defined(SSL_CTRL_SET_CURVES_LIST) || defined(SSL_CTX_set1_curves_list)
|
#if defined(SSL_CTRL_SET_ECDH_AUTO) || defined(SSL_CTRL_SET_CURVES_LIST) || defined(SSL_CTX_set1_curves_list)
|
||||||
lua_pushstring(L, "curves_list");
|
lua_pushstring(L, "curves_list");
|
||||||
|
@ -609,6 +609,91 @@ static int set_curves_list(lua_State *L)
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the protocols a client should send for ALPN.
|
||||||
|
*/
|
||||||
|
static int set_alpn(lua_State *L)
|
||||||
|
{
|
||||||
|
long ret;
|
||||||
|
size_t len;
|
||||||
|
p_context ctx = checkctx(L, 1);
|
||||||
|
const char *str = luaL_checklstring(L, 2, &len);
|
||||||
|
|
||||||
|
ret = SSL_CTX_set_alpn_protos(ctx->context, (const unsigned char*)str, len);
|
||||||
|
if (ret) {
|
||||||
|
lua_pushboolean(L, 0);
|
||||||
|
lua_pushfstring(L, "error setting ALPN (%s)", ERR_reason_error_string(ERR_get_error()));
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
lua_pushboolean(L, 1);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This standard callback calls the server's callback in Lua sapce.
|
||||||
|
* The server has to return a list in wire-format strings.
|
||||||
|
* This function uses a helper function to match server and client lists.
|
||||||
|
*/
|
||||||
|
static int alpn_cb(SSL *s, const unsigned char **out, unsigned char *outlen,
|
||||||
|
const unsigned char *in, unsigned int inlen, void *arg)
|
||||||
|
{
|
||||||
|
int ret;
|
||||||
|
size_t server_len;
|
||||||
|
const char *server;
|
||||||
|
p_context ctx = (p_context)arg;
|
||||||
|
lua_State *L = ctx->L;
|
||||||
|
|
||||||
|
luaL_getmetatable(L, "SSL:ALPN:Registry");
|
||||||
|
lua_pushlightuserdata(L, (void*)ctx->context);
|
||||||
|
lua_gettable(L, -2);
|
||||||
|
|
||||||
|
lua_pushlstring(L, (const char*)in, inlen);
|
||||||
|
|
||||||
|
lua_call(L, 1, 1);
|
||||||
|
|
||||||
|
if (!lua_isstring(L, -1)) {
|
||||||
|
lua_pop(L, 2);
|
||||||
|
return SSL_TLSEXT_ERR_NOACK;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Protocol list from server in wire-format string
|
||||||
|
server = luaL_checklstring(L, -1, &server_len);
|
||||||
|
ret = SSL_select_next_proto((unsigned char**)out, outlen, (const unsigned char*)server,
|
||||||
|
server_len, in, inlen);
|
||||||
|
if (ret != OPENSSL_NPN_NEGOTIATED) {
|
||||||
|
lua_pop(L, 2);
|
||||||
|
return SSL_TLSEXT_ERR_NOACK;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy the result because lua_pop() can collect the pointer
|
||||||
|
ctx->alpn = malloc(*outlen);
|
||||||
|
memcpy(ctx->alpn, (void*)*out, *outlen);
|
||||||
|
*out = (const unsigned char*)ctx->alpn;
|
||||||
|
|
||||||
|
lua_pop(L, 2);
|
||||||
|
|
||||||
|
return SSL_TLSEXT_ERR_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set a callback a server can use to select the next protocol with ALPN.
|
||||||
|
*/
|
||||||
|
static int set_alpn_cb(lua_State *L)
|
||||||
|
{
|
||||||
|
p_context ctx = checkctx(L, 1);
|
||||||
|
|
||||||
|
luaL_getmetatable(L, "SSL:ALPN:Registry");
|
||||||
|
lua_pushlightuserdata(L, (void*)ctx->context);
|
||||||
|
lua_pushvalue(L, 2);
|
||||||
|
lua_settable(L, -3);
|
||||||
|
|
||||||
|
SSL_CTX_set_alpn_select_cb(ctx->context, alpn_cb, ctx);
|
||||||
|
|
||||||
|
lua_pushboolean(L, 1);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Package functions
|
* Package functions
|
||||||
*/
|
*/
|
||||||
@ -618,6 +703,8 @@ static luaL_Reg funcs[] = {
|
|||||||
{"loadcert", load_cert},
|
{"loadcert", load_cert},
|
||||||
{"loadkey", load_key},
|
{"loadkey", load_key},
|
||||||
{"checkkey", check_key},
|
{"checkkey", check_key},
|
||||||
|
{"setalpn", set_alpn},
|
||||||
|
{"setalpncb", set_alpn_cb},
|
||||||
{"setcipher", set_cipher},
|
{"setcipher", set_cipher},
|
||||||
{"setdepth", set_depth},
|
{"setdepth", set_depth},
|
||||||
{"setdhparam", set_dhparam},
|
{"setdhparam", set_dhparam},
|
||||||
@ -654,6 +741,10 @@ static int meth_destroy(lua_State *L)
|
|||||||
lua_pushlightuserdata(L, (void*)ctx->context);
|
lua_pushlightuserdata(L, (void*)ctx->context);
|
||||||
lua_pushnil(L);
|
lua_pushnil(L);
|
||||||
lua_settable(L, -3);
|
lua_settable(L, -3);
|
||||||
|
luaL_getmetatable(L, "SSL:ALPN:Registry");
|
||||||
|
lua_pushlightuserdata(L, (void*)ctx->context);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_settable(L, -3);
|
||||||
|
|
||||||
SSL_CTX_free(ctx->context);
|
SSL_CTX_free(ctx->context);
|
||||||
ctx->context = NULL;
|
ctx->context = NULL;
|
||||||
@ -795,8 +886,9 @@ void *lsec_testudata (lua_State *L, int ud, const char *tname) {
|
|||||||
*/
|
*/
|
||||||
LSEC_API int luaopen_ssl_context(lua_State *L)
|
LSEC_API int luaopen_ssl_context(lua_State *L)
|
||||||
{
|
{
|
||||||
luaL_newmetatable(L, "SSL:DH:Registry"); /* Keep all DH callbacks */
|
luaL_newmetatable(L, "SSL:DH:Registry"); /* Keep all DH callbacks */
|
||||||
luaL_newmetatable(L, "SSL:Verify:Registry"); /* Keep all verify flags */
|
luaL_newmetatable(L, "SSL:ALPN:Registry"); /* Keep all ALPN callbacks */
|
||||||
|
luaL_newmetatable(L, "SSL:Verify:Registry"); /* Keep all verify flags */
|
||||||
luaL_newmetatable(L, "SSL:Context");
|
luaL_newmetatable(L, "SSL:Context");
|
||||||
setfuncs(L, meta);
|
setfuncs(L, meta);
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ typedef struct t_context_ {
|
|||||||
SSL_CTX *context;
|
SSL_CTX *context;
|
||||||
lua_State *L;
|
lua_State *L;
|
||||||
DH *dh_param;
|
DH *dh_param;
|
||||||
|
void *alpn;
|
||||||
int mode;
|
int mode;
|
||||||
} t_context;
|
} t_context;
|
||||||
typedef t_context* p_context;
|
typedef t_context* p_context;
|
||||||
|
18
src/ssl.c
18
src/ssl.c
@ -389,6 +389,10 @@ static int meth_handshake(lua_State *L)
|
|||||||
DH_free(ctx->dh_param);
|
DH_free(ctx->dh_param);
|
||||||
ctx->dh_param = NULL;
|
ctx->dh_param = NULL;
|
||||||
}
|
}
|
||||||
|
if (ctx->alpn) {
|
||||||
|
free(ctx->alpn);
|
||||||
|
ctx->alpn = NULL;
|
||||||
|
}
|
||||||
if (err == IO_DONE) {
|
if (err == IO_DONE) {
|
||||||
lua_pushboolean(L, 1);
|
lua_pushboolean(L, 1);
|
||||||
return 1;
|
return 1;
|
||||||
@ -799,6 +803,19 @@ static int meth_getsniname(lua_State *L)
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int meth_getalpn(lua_State *L)
|
||||||
|
{
|
||||||
|
unsigned len;
|
||||||
|
const unsigned char *data;
|
||||||
|
p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection");
|
||||||
|
SSL_get0_alpn_selected(ssl->ssl, &data, &len);
|
||||||
|
if (data == NULL && len == 0)
|
||||||
|
lua_pushnil(L);
|
||||||
|
else
|
||||||
|
lua_pushlstring(L, (const char*)data, len);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
static int meth_copyright(lua_State *L)
|
static int meth_copyright(lua_State *L)
|
||||||
{
|
{
|
||||||
lua_pushstring(L, "LuaSec 0.7 - Copyright (C) 2006-2018 Bruno Silvestre, UFG"
|
lua_pushstring(L, "LuaSec 0.7 - Copyright (C) 2006-2018 Bruno Silvestre, UFG"
|
||||||
@ -816,6 +833,7 @@ static int meth_copyright(lua_State *L)
|
|||||||
*/
|
*/
|
||||||
static luaL_Reg methods[] = {
|
static luaL_Reg methods[] = {
|
||||||
{"close", meth_close},
|
{"close", meth_close},
|
||||||
|
{"getalpn", meth_getalpn},
|
||||||
{"getfd", meth_getfd},
|
{"getfd", meth_getfd},
|
||||||
{"getfinished", meth_getfinished},
|
{"getfinished", meth_getfinished},
|
||||||
{"getpeercertificate", meth_getpeercertificate},
|
{"getpeercertificate", meth_getpeercertificate},
|
||||||
|
75
src/ssl.lua
75
src/ssl.lua
@ -30,6 +30,39 @@ local function optexec(func, param, ctx)
|
|||||||
return true
|
return true
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--
|
||||||
|
-- Convert an array of strings to wire-format
|
||||||
|
--
|
||||||
|
local function array2wireformat(array)
|
||||||
|
local str = ""
|
||||||
|
for k, v in ipairs(array) do
|
||||||
|
if type(v) ~= "string" then return nil end
|
||||||
|
local len = #v
|
||||||
|
if len == 0 then
|
||||||
|
return nil, "invalid ALPN name (empty string)"
|
||||||
|
elseif len > 255 then
|
||||||
|
return nil, "invalid ALPN name (length > 255)"
|
||||||
|
end
|
||||||
|
str = str .. string.char(len) .. v
|
||||||
|
end
|
||||||
|
if str == "" then return nil, "invalid ALPN list (empty)" end
|
||||||
|
return str
|
||||||
|
end
|
||||||
|
|
||||||
|
--
|
||||||
|
-- Convert wire-string format to array
|
||||||
|
--
|
||||||
|
local function wireformat2array(str)
|
||||||
|
local i = 1
|
||||||
|
local array = {}
|
||||||
|
while i < #str do
|
||||||
|
local len = str:byte(i)
|
||||||
|
array[#array + 1] = str:sub(i + 1, i + len)
|
||||||
|
i = i + len + 1
|
||||||
|
end
|
||||||
|
return array
|
||||||
|
end
|
||||||
|
|
||||||
--
|
--
|
||||||
--
|
--
|
||||||
--
|
--
|
||||||
@ -113,6 +146,48 @@ local function newcontext(cfg)
|
|||||||
if not succ then return nil, msg end
|
if not succ then return nil, msg end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- ALPN
|
||||||
|
if cfg.mode == "server" and cfg.alpn then
|
||||||
|
if type(cfg.alpn) == "function" then
|
||||||
|
local alpncb = cfg.alpn
|
||||||
|
-- This callback function has to return one value only
|
||||||
|
succ, msg = context.setalpncb(ctx, function(str)
|
||||||
|
local protocols = alpncb(wireformat2array(str))
|
||||||
|
if type(protocols) == "string" then
|
||||||
|
protocols = { protocols }
|
||||||
|
elseif type(protocols) ~= "table" then
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
return (array2wireformat(protocols)) -- use "()" to drop error message
|
||||||
|
end)
|
||||||
|
if not succ then return nil, msg end
|
||||||
|
elseif type(cfg.alpn) == "table" then
|
||||||
|
local protocols = cfg.alpn
|
||||||
|
-- check if array is valid before use it
|
||||||
|
succ, msg = array2wireformat(protocols)
|
||||||
|
if not succ then return nil, msg end
|
||||||
|
-- This callback function has to return one value only
|
||||||
|
succ, msg = context.setalpncb(ctx, function()
|
||||||
|
return (array2wireformat(protocols)) -- use "()" to drop error message
|
||||||
|
end)
|
||||||
|
if not succ then return nil, msg end
|
||||||
|
else
|
||||||
|
return nil, "invalid ALPN parameter"
|
||||||
|
end
|
||||||
|
elseif cfg.mode == "client" and cfg.alpn then
|
||||||
|
local alpn
|
||||||
|
if type(cfg.alpn) == "string" then
|
||||||
|
alpn, msg = array2wireformat({ cfg.alpn })
|
||||||
|
elseif type(cfg.alpn) == "table" then
|
||||||
|
alpn, msg = array2wireformat(cfg.alpn)
|
||||||
|
else
|
||||||
|
return nil, "invalid ALPN parameter"
|
||||||
|
end
|
||||||
|
if not alpn then return nil, msg end
|
||||||
|
succ, msg = context.setalpn(ctx, alpn)
|
||||||
|
if not succ then return nil, msg end
|
||||||
|
end
|
||||||
|
|
||||||
return ctx
|
return ctx
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user