diff options
-rw-r--r-- | src/openssl.c | 106 |
1 files changed, 83 insertions, 23 deletions
diff --git a/src/openssl.c b/src/openssl.c index 2dee037..773930b 100644 --- a/src/openssl.c +++ b/src/openssl.c @@ -26,7 +26,7 @@ #ifndef LUAOSSL_H #define LUAOSSL_H -#include <limits.h> /* INT_MAX INT_MIN */ +#include <limits.h> /* INT_MAX INT_MIN UCHAR_MAX */ #include <stdint.h> /* uintptr_t */ #include <string.h> /* memset(3) strerror_r(3) */ #include <strings.h> /* strcasecmp(3) */ @@ -87,6 +87,10 @@ #define HAVE_SSL_CTX_SET_ALPN_PROTOS (OPENSSL_VERSION_NUMBER >= 0x1000200fL) #endif +#ifndef HAVE_SSL_SET_ALPN_PROTOS +#define HAVE_SSL_SET_ALPN_PROTOS HAVE_SSL_CTX_SET_ALPN_PROTOS +#endif + #ifndef HAVE_SSL_GET0_ALPN_SELECTED #define HAVE_SSL_GET0_ALPN_SELECTED HAVE_SSL_CTX_SET_ALPN_PROTOS #endif @@ -346,6 +350,45 @@ static int optencoding(lua_State *L, int index, const char *def, int allow) { } /* optencoding() */ +static _Bool rawgeti(lua_State *L, int index, int n) { + lua_rawgeti(L, index, n); + + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + + return 0; + } else { + return 1; + } +} /* rawgeti() */ + + +/* check ALPN protocols and add to buffer of length-prefixed strings */ +static void checkprotos(luaL_Buffer *B, lua_State *L, int index) { + int n; + + luaL_checktype(L, index, LUA_TTABLE); + + for (n = 1; rawgeti(L, index, n); n++) { + const char *tmp; + size_t len; + + switch (lua_type(L, -1)) { + case LUA_TSTRING: + break; + default: + luaL_argerror(L, index, "array of strings expected"); + } + + tmp = luaL_checklstring(L, -1, &len); + luaL_argcheck(L, len > 0 && len <= UCHAR_MAX, index, "proto string length invalid"); + luaL_addchar(B, (unsigned char)len); + luaL_addlstring(B, tmp, len); + lua_pop(L, 1); + } +} /* checkprotos() */ + + static _Bool getfield(lua_State *L, int index, const char *k) { lua_getfield(L, index, k); @@ -4524,34 +4567,16 @@ static int sx_setEphemeralKey(lua_State *L) { return 1; } /* sx_setEphemeralKey() */ + #if HAVE_SSL_CTX_SET_ALPN_PROTOS static int sx_setAlpnProtos(lua_State *L) { SSL_CTX *ctx = checksimple(L, 1, SSL_CTX_CLASS); + luaL_Buffer B; size_t len; const char *tmp; - unsigned protos_len = 0; - luaL_Buffer B; - luaL_checktype(L, 2, LUA_TTABLE); - luaL_buffinit(L, &B); - while (1) { - protos_len++; - lua_rawgeti(L, 2, protos_len); - switch (lua_type(L, -1)) { - case LUA_TNIL: - goto done; - case LUA_TSTRING: - break; - default: - return luaL_argerror(L, 2, "array of strings expected"); - } - tmp = luaL_checklstring(L, -1, &len); - luaL_argcheck(L, len > 0 && len <= UCHAR_MAX, 2, "proto string length invalid"); - luaL_addchar(&B, (unsigned char)len); - luaL_addlstring(&B, tmp, len); - lua_pop(L, 1); - } -done: + luaL_buffinit(L, &B); + checkprotos(&B, L, 2); luaL_pushresult(&B); tmp = lua_tolstring(L, -1, &len); @@ -4571,6 +4596,7 @@ done: } /* sx_setAlpnProtos() */ #endif + static int sx__gc(lua_State *L) { SSL_CTX **ud = luaL_checkudata(L, 1, SSL_CTX_CLASS); @@ -4847,6 +4873,7 @@ static int ssl_getClientVersion(lua_State *L) { return 1; } /* ssl_getClientVersion() */ + #if HAVE_SSL_GET0_ALPN_SELECTED static int ssl_getAlpnSelected(lua_State *L) { SSL *ssl = checksimple(L, 1, SSL_CLASS); @@ -4862,6 +4889,36 @@ static int ssl_getAlpnSelected(lua_State *L) { } /* ssl_getAlpnSelected() */ #endif + +#if HAVE_SSL_SET_ALPN_PROTOS +static int ssl_setAlpnProtos(lua_State *L) { + SSL *ssl = checksimple(L, 1, SSL_CLASS); + luaL_Buffer B; + size_t len; + const char *tmp; + + luaL_buffinit(L, &B); + checkprotos(&B, L, 2); + luaL_pushresult(&B); + tmp = lua_tolstring(L, -1, &len); + + /* OpenSSL 1.0.2 doesn't update the error stack on failure. */ + ERR_clear_error(); + if (0 != SSL_set_alpn_protos(ssl, (const unsigned char*)tmp, len)) { + if (!ERR_peek_error()) { + return luaL_error(L, "unable to set ALPN protocols: %s", xstrerror(ENOMEM)); + } else { + return throwssl(L, "ssl:setAlpnProtos"); + } + } + + lua_pushboolean(L, 1); + + return 1; +} /* ssl_setAlpnProtos() */ +#endif + + static int ssl__gc(lua_State *L) { SSL **ud = luaL_checkudata(L, 1, SSL_CLASS); @@ -4888,6 +4945,9 @@ static const luaL_Reg ssl_methods[] = { #if HAVE_SSL_GET0_ALPN_SELECTED { "getAlpnSelected", &ssl_getAlpnSelected }, #endif +#if HAVE_SSL_SET_ALPN_PROTOS + { "setAlpnProtos", &ssl_setAlpnProtos }, +#endif { NULL, NULL }, }; |