diff --git a/samples/sni/client.lua b/samples/sni/client.lua index 1487098..79eb004 100644 --- a/samples/sni/client.lua +++ b/samples/sni/client.lua @@ -19,7 +19,8 @@ conn = ssl.wrap(conn, params) -- Comment the lines to not send a name --conn:sni("servera.br") -conn:sni("serveraa.br") +--conn:sni("serveraa.br") +conn:sni("serverb.br") assert(conn:dohandshake()) -- diff --git a/samples/sni/server.lua b/samples/sni/server.lua index 8ac4be2..101bc2f 100644 --- a/samples/sni/server.lua +++ b/samples/sni/server.lua @@ -39,10 +39,12 @@ local conn = server:accept() conn = ssl.wrap(conn, ctx01) -- Configure the name map -conn:sni({ +local sni_map = { ["servera.br"] = ctx01, ["serveraa.br"] = ctx02, -}) +} + +conn:sni(sni_map, true) assert(conn:dohandshake()) -- diff --git a/src/ssl.c b/src/ssl.c index 92d0881..b5591ac 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -660,6 +660,7 @@ static int meth_info(lua_State *L) static int sni_cb(SSL *ssl, int *ad, void *arg) { + int strict; SSL_CTX *newctx = NULL; SSL_CTX *ctx = SSL_get_SSL_CTX(ssl); lua_State *L = ((p_context)SSL_CTX_get_app_data(ctx))->L; @@ -667,41 +668,54 @@ static int sni_cb(SSL *ssl, int *ad, void *arg) /* No name, use default context */ if (!name) return SSL_TLSEXT_ERR_NOACK; - /* Search for the name in the map */ + /* Retrieve struct from registry */ luaL_getmetatable(L, "SSL:SNI:Registry"); lua_pushlightuserdata(L, (void*)ssl); lua_gettable(L, -2); + /* Strict search? */ + lua_pushstring(L, "strict"); + lua_gettable(L, -2); + strict = lua_toboolean(L, -1); + lua_pop(L, 1); + /* Search for the name in the map */ + lua_pushstring(L, "map"); + lua_gettable(L, -2); lua_pushstring(L, name); lua_gettable(L, -2); if (lua_isuserdata(L, -1)) newctx = lsec_checkcontext(L, -1); - lua_pop(L, 3); + lua_pop(L, 4); + /* Found, use this context */ if (newctx) { SSL_set_SSL_CTX(ssl, newctx); return SSL_TLSEXT_ERR_OK; } + /* Not found, but use initial context */ + if (!strict) + return SSL_TLSEXT_ERR_OK; return SSL_TLSEXT_ERR_ALERT_FATAL; } static int meth_sni(lua_State *L) { + int strict; SSL_CTX *aux; const char *name; p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); SSL_CTX *ctx = SSL_get_SSL_CTX(ssl->ssl); p_context pctx = (p_context)SSL_CTX_get_app_data(ctx); - switch (pctx->mode) { - case LSEC_MODE_CLIENT: + if (pctx->mode == LSEC_MODE_CLIENT) { name = luaL_checkstring(L, 2); SSL_set_tlsext_host_name(ssl->ssl, name); - break; - case LSEC_MODE_SERVER: + return 0; + } else if (pctx->mode == LSEC_MODE_SERVER) { luaL_checktype(L, 2, LUA_TTABLE); + strict = lua_toboolean(L, 3); /* Check if the table contains only (string -> context) */ lua_pushnil(L); while (lua_next(L, 2)) { - luaL_checkstring(L, 3); - aux = lsec_checkcontext(L, 4); + luaL_checkstring(L, -2); + aux = lsec_checkcontext(L, -1); /* Set callback in every context */ SSL_CTX_set_tlsext_servername_callback(aux, sni_cb); /* leave the next key on the stack */ @@ -710,15 +724,31 @@ static int meth_sni(lua_State *L) /* Save table in the register */ luaL_getmetatable(L, "SSL:SNI:Registry"); lua_pushlightuserdata(L, (void*)ssl->ssl); + lua_newtable(L); + lua_pushstring(L, "map"); lua_pushvalue(L, 2); lua_settable(L, -3); + lua_pushstring(L, "strict"); + lua_pushboolean(L, strict); + lua_settable(L, -3); + lua_settable(L, -3); /* Set callback in the default context */ SSL_CTX_set_tlsext_servername_callback(ctx, sni_cb); - break; } return 0; } +static int meth_getsniname(lua_State *L) +{ + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + const char *name = SSL_get_servername(ssl->ssl, TLSEXT_NAMETYPE_host_name); + if (name) + lua_pushstring(L, name); + else + lua_pushnil(L); + return 1; +} + static int meth_copyright(lua_State *L) { lua_pushstring(L, "LuaSec 0.5 - Copyright (C) 2006-2011 Bruno Silvestre" @@ -742,6 +772,7 @@ static luaL_Reg methods[] = { {"getpeerchain", meth_getpeerchain}, {"getpeerverification", meth_getpeerverification}, {"getpeerfinished", meth_getpeerfinished}, + {"getsniname", meth_getsniname}, {"getstats", meth_getstats}, {"setstats", meth_setstats}, {"dirty", meth_dirty},