From afeb65d5c21019132a9b7759f864f19e0ad0df43 Mon Sep 17 00:00:00 2001 From: william Date: Thu, 5 Mar 2015 14:57:45 -0800 Subject: add openssl.ssl:setAlpnProtos --- src/openssl.c | 106 +++++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 83 insertions(+), 23 deletions(-) (limited to 'src') diff --git a/src/openssl.c b/src/openssl.c index 2dee037..773930b 100644 --- a/src/openssl.c +++ b/src/openssl.c @@ -26,7 +26,7 @@ #ifndef LUAOSSL_H #define LUAOSSL_H -#include /* INT_MAX INT_MIN */ +#include /* INT_MAX INT_MIN UCHAR_MAX */ #include /* uintptr_t */ #include /* memset(3) strerror_r(3) */ #include /* strcasecmp(3) */ @@ -87,6 +87,10 @@ #define HAVE_SSL_CTX_SET_ALPN_PROTOS (OPENSSL_VERSION_NUMBER >= 0x1000200fL) #endif +#ifndef HAVE_SSL_SET_ALPN_PROTOS +#define HAVE_SSL_SET_ALPN_PROTOS HAVE_SSL_CTX_SET_ALPN_PROTOS +#endif + #ifndef HAVE_SSL_GET0_ALPN_SELECTED #define HAVE_SSL_GET0_ALPN_SELECTED HAVE_SSL_CTX_SET_ALPN_PROTOS #endif @@ -346,6 +350,45 @@ static int optencoding(lua_State *L, int index, const char *def, int allow) { } /* optencoding() */ +static _Bool rawgeti(lua_State *L, int index, int n) { + lua_rawgeti(L, index, n); + + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + + return 0; + } else { + return 1; + } +} /* rawgeti() */ + + +/* check ALPN protocols and add to buffer of length-prefixed strings */ +static void checkprotos(luaL_Buffer *B, lua_State *L, int index) { + int n; + + luaL_checktype(L, index, LUA_TTABLE); + + for (n = 1; rawgeti(L, index, n); n++) { + const char *tmp; + size_t len; + + switch (lua_type(L, -1)) { + case LUA_TSTRING: + break; + default: + luaL_argerror(L, index, "array of strings expected"); + } + + tmp = luaL_checklstring(L, -1, &len); + luaL_argcheck(L, len > 0 && len <= UCHAR_MAX, index, "proto string length invalid"); + luaL_addchar(B, (unsigned char)len); + luaL_addlstring(B, tmp, len); + lua_pop(L, 1); + } +} /* checkprotos() */ + + static _Bool getfield(lua_State *L, int index, const char *k) { lua_getfield(L, index, k); @@ -4524,34 +4567,16 @@ static int sx_setEphemeralKey(lua_State *L) { return 1; } /* sx_setEphemeralKey() */ + #if HAVE_SSL_CTX_SET_ALPN_PROTOS static int sx_setAlpnProtos(lua_State *L) { SSL_CTX *ctx = checksimple(L, 1, SSL_CTX_CLASS); + luaL_Buffer B; size_t len; const char *tmp; - unsigned protos_len = 0; - luaL_Buffer B; - luaL_checktype(L, 2, LUA_TTABLE); - luaL_buffinit(L, &B); - while (1) { - protos_len++; - lua_rawgeti(L, 2, protos_len); - switch (lua_type(L, -1)) { - case LUA_TNIL: - goto done; - case LUA_TSTRING: - break; - default: - return luaL_argerror(L, 2, "array of strings expected"); - } - tmp = luaL_checklstring(L, -1, &len); - luaL_argcheck(L, len > 0 && len <= UCHAR_MAX, 2, "proto string length invalid"); - luaL_addchar(&B, (unsigned char)len); - luaL_addlstring(&B, tmp, len); - lua_pop(L, 1); - } -done: + luaL_buffinit(L, &B); + checkprotos(&B, L, 2); luaL_pushresult(&B); tmp = lua_tolstring(L, -1, &len); @@ -4571,6 +4596,7 @@ done: } /* sx_setAlpnProtos() */ #endif + static int sx__gc(lua_State *L) { SSL_CTX **ud = luaL_checkudata(L, 1, SSL_CTX_CLASS); @@ -4847,6 +4873,7 @@ static int ssl_getClientVersion(lua_State *L) { return 1; } /* ssl_getClientVersion() */ + #if HAVE_SSL_GET0_ALPN_SELECTED static int ssl_getAlpnSelected(lua_State *L) { SSL *ssl = checksimple(L, 1, SSL_CLASS); @@ -4862,6 +4889,36 @@ static int ssl_getAlpnSelected(lua_State *L) { } /* ssl_getAlpnSelected() */ #endif + +#if HAVE_SSL_SET_ALPN_PROTOS +static int ssl_setAlpnProtos(lua_State *L) { + SSL *ssl = checksimple(L, 1, SSL_CLASS); + luaL_Buffer B; + size_t len; + const char *tmp; + + luaL_buffinit(L, &B); + checkprotos(&B, L, 2); + luaL_pushresult(&B); + tmp = lua_tolstring(L, -1, &len); + + /* OpenSSL 1.0.2 doesn't update the error stack on failure. */ + ERR_clear_error(); + if (0 != SSL_set_alpn_protos(ssl, (const unsigned char*)tmp, len)) { + if (!ERR_peek_error()) { + return luaL_error(L, "unable to set ALPN protocols: %s", xstrerror(ENOMEM)); + } else { + return throwssl(L, "ssl:setAlpnProtos"); + } + } + + lua_pushboolean(L, 1); + + return 1; +} /* ssl_setAlpnProtos() */ +#endif + + static int ssl__gc(lua_State *L) { SSL **ud = luaL_checkudata(L, 1, SSL_CLASS); @@ -4887,6 +4944,9 @@ static const luaL_Reg ssl_methods[] = { { "getClientVersion", &ssl_getClientVersion }, #if HAVE_SSL_GET0_ALPN_SELECTED { "getAlpnSelected", &ssl_getAlpnSelected }, +#endif +#if HAVE_SSL_SET_ALPN_PROTOS + { "setAlpnProtos", &ssl_setAlpnProtos }, #endif { NULL, NULL }, }; -- cgit v1.2.3-59-g8ed1b