From 4cecbb2783c758b683954a0d6b1dcef273d22398 Mon Sep 17 00:00:00 2001 From: Matthew Wild Date: Wed, 21 Sep 2022 18:40:10 +0100 Subject: [PATCH] ssl: Add :getlocalchain() + :getlocalcertificate() to mirror the peer methods These methods mirror the existing methods that fetch the peer certificate and chain. Due to various factors (SNI, multiple key types, etc.) it is not always trivial for an application to determine what certificate was presented to the client. However there are various use-cases where this is needed, such as tls-server-end-point channel binding and OCSP stapling. Requires OpenSSL 1.0.2+ (note: SSL_get_certificate() has existed for a very long time, but was lacking documentation until OpenSSL 3.0). --- samples/chain/server.lua | 21 +++++++++- src/ssl.c | 89 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 1 deletion(-) diff --git a/samples/chain/server.lua b/samples/chain/server.lua index a560fbe..8bf3d36 100644 --- a/samples/chain/server.lua +++ b/samples/chain/server.lua @@ -31,8 +31,27 @@ util.show( conn:getpeercertificate() ) print("----------------------------------------------------------------------") -for k, cert in ipairs( conn:getpeerchain() ) do +local expectedpeerchain = { "../certs/clientAcert.pem", "../certs/rootA.pem" } + +local peerchain = conn:getpeerchain() +assert(#peerchain == #expectedpeerchain) +for k, cert in ipairs( peerchain ) do util.show(cert) + local expectedpem = assert(io.open(expectedpeerchain[k])):read("*a") + assert(cert:pem() == expectedpem, "peer chain mismatch @ "..tostring(k)) +end + +local expectedlocalchain = { "../certs/serverAcert.pem" } + +local localchain = assert(conn:getlocalchain()) +assert(#localchain == #expectedlocalchain) +for k, cert in ipairs( localchain ) do + util.show(cert) + local expectedpem = assert(io.open(expectedlocalchain[k])):read("*a") + assert(cert:pem() == expectedpem, "local chain mismatch @ "..tostring(k)) + if k == 1 then + assert(cert:pem() == conn:getlocalcertificate():pem()) + end end local f = io.open(params.certificate) diff --git a/src/ssl.c b/src/ssl.c index c546a87..ce83c9d 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -530,6 +530,58 @@ static int meth_getpeercertificate(lua_State *L) return 1; } +/** + * Return the nth certificate of the chain sent to our peer. + */ +static int meth_getlocalcertificate(lua_State *L) +{ + int n; + X509 *cert; + STACK_OF(X509) *certs; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + if (ssl->state != LSEC_STATE_CONNECTED) { + lua_pushnil(L); + lua_pushstring(L, "closed"); + return 2; + } + /* Default to the first cert */ + n = (int)luaL_optinteger(L, 2, 1); + /* This function is 1-based, but OpenSSL is 0-based */ + --n; + if (n < 0) { + lua_pushnil(L); + lua_pushliteral(L, "invalid certificate index"); + return 2; + } + if (n == 0) { + cert = SSL_get_certificate(ssl->ssl); + if (cert) + lsec_pushx509(L, cert); + else + lua_pushnil(L); + return 1; + } + /* In a server-context, the stack doesn't contain the peer cert, + * so adjust accordingly. + */ + if (SSL_is_server(ssl->ssl)) + --n; + if(SSL_get0_chain_certs(ssl->ssl, &certs) != 1) { + lua_pushnil(L); + } else { + if (n >= sk_X509_num(certs)) { + lua_pushnil(L); + return 1; + } + cert = sk_X509_value(certs, n); + /* Increment the reference counting of the object. */ + /* See SSL_get_peer_certificate() source code. */ + X509_up_ref(cert); + lsec_pushx509(L, cert); + } + return 1; +} + /** * Return the chain of certificate of the peer. */ @@ -564,6 +616,41 @@ static int meth_getpeerchain(lua_State *L) return 1; } +/** + * Return the chain of certificates sent to the peer. + */ +static int meth_getlocalchain(lua_State *L) +{ + int i; + int idx = 1; + int n_certs; + X509 *cert; + STACK_OF(X509) *certs; + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + if (ssl->state != LSEC_STATE_CONNECTED) { + lua_pushnil(L); + lua_pushstring(L, "closed"); + return 2; + } + lua_newtable(L); + if (SSL_is_server(ssl->ssl)) { + lsec_pushx509(L, SSL_get_certificate(ssl->ssl)); + lua_rawseti(L, -2, idx++); + } + if(SSL_get0_chain_certs(ssl->ssl, &certs)) { + n_certs = sk_X509_num(certs); + for (i = 0; i < n_certs; i++) { + cert = sk_X509_value(certs, i); + /* Increment the reference counting of the object. */ + /* See SSL_get_peer_certificate() source code. */ + X509_up_ref(cert); + lsec_pushx509(L, cert); + lua_rawseti(L, -2, idx++); + } + } + return 1; +} + /** * Copy the table src to the table dst. */ @@ -908,7 +995,9 @@ static luaL_Reg methods[] = { {"getfd", meth_getfd}, {"getfinished", meth_getfinished}, {"getpeercertificate", meth_getpeercertificate}, + {"getlocalcertificate", meth_getlocalcertificate}, {"getpeerchain", meth_getpeerchain}, + {"getlocalchain", meth_getlocalchain}, {"getpeerverification", meth_getpeerverification}, {"getpeerfinished", meth_getpeerfinished}, {"exportkeyingmaterial",meth_exportkeyingmaterial},