From 55da052192e151ee055e62a378d6ebdbdcbe5087 Mon Sep 17 00:00:00 2001
From: william <william@25thandclement.com>
Date: Thu, 9 Apr 2015 20:53:43 -0700
Subject: refactor and fixup some interfaces, and begin to flesh out ALPN
 selection callback

---
 src/openssl.c | 187 ++++++++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 148 insertions(+), 39 deletions(-)

diff --git a/src/openssl.c b/src/openssl.c
index 4e0b898..3465922 100644
--- a/src/openssl.c
+++ b/src/openssl.c
@@ -424,6 +424,24 @@ static void checkprotos(luaL_Buffer *B, lua_State *L, int index) {
 	}
 } /* checkprotos() */
 
+static void pushprotos(lua_State *L, const unsigned char *p, size_t n) {
+	const unsigned char *pe = &p[n];
+	int i = 0;
+
+	lua_newtable(L);
+
+	while (p < pe) {
+		n = *p++;
+
+		if ((size_t)(pe - p) < n)
+			luaL_error(L, "corrupt ALPN protocol list (%zu > %zu)", n, (size_t)(pe - p));
+
+		lua_pushlstring(L, (const void *)p, n);
+		lua_rawseti(L, -2, ++i);
+		p += n;
+	}
+} /* pushprotos() */
+
 
 static _Bool getfield(lua_State *L, int index, const char *k) {
 	lua_getfield(L, index, k);
@@ -604,15 +622,44 @@ static void *compat_EVP_PKEY_get0(EVP_PKEY *key) {
 #endif
 
 
+typedef int auxref_t;
+typedef int auxtype_t;
+
+static void auxL_unref(lua_State *L, auxref_t *ref) {
+	luaL_unref(L, LUA_REGISTRYINDEX, *ref);
+	*ref = LUA_NOREF;
+} /* auxL_unref() */
+
+static void auxL_ref(lua_State *L, int index, auxref_t *ref) {
+	auxL_unref(L, ref);
+	lua_pushvalue(L, index);
+	*ref = luaL_ref(L, LUA_REGISTRYINDEX);
+} /* auxL_ref() */
+
+static auxtype_t auxL_getref(lua_State *L, auxref_t ref) {
+	if (ref == LUA_NOREF || ref == LUA_REFNIL) {
+		lua_pushnil(L);
+	} else {
+		lua_rawgeti(L, LUA_REGISTRYINDEX, ref);
+	}
+
+	return lua_type(L, -1);
+} /* auxL_getref() */
+
+
 struct ex_state {
-	lua_State *mainthread;
+	lua_State *L;
 	LIST_HEAD(, ex_data) data;
 }; /* struct ex_state */
 
+#ifndef EX_DATA_MAXARGS
+#define EX_DATA_MAXARGS 4
+#endif
+
 struct ex_data {
 	struct ex_state *state;
 	int refs;
-	int arg[4];
+	auxref_t arg[EX_DATA_MAXARGS];
 	LIST_ENTRY(ex_data) le;
 }; /* struct ex_data */
 
@@ -621,40 +668,47 @@ enum {
 };
 
 static struct ex_type {
-	int class_index;
-	int index;
+	int class_index; /* OpenSSL object type identifier */
+	int index; /* OpenSSL-allocated external data identifier */
 	void *(*get_ex_data)();
 	int (*set_ex_data)();
 } ex_type[] = {
 	[EX_SSL_CTX_ALPN_SELECT_CB] = { CRYPTO_EX_INDEX_SSL_CTX, -1, &SSL_CTX_get_ex_data, &SSL_CTX_set_ex_data },
 };
 
-static int ex_data_dup(CRYPTO_EX_DATA *to NOTUSED, CRYPTO_EX_DATA *from NOTUSED, void *from_d, int idx NOTUSED, long argl NOTUSED, void *argp NOTUSED) {
+static int ex_ondup(CRYPTO_EX_DATA *to NOTUSED, CRYPTO_EX_DATA *from NOTUSED, void *from_d, int idx NOTUSED, long argl NOTUSED, void *argp NOTUSED) {
 	struct ex_data **data = from_d;
 
 	if (*data)
 		(*data)->refs++;
 
 	return 1;
-} /* ex_data_dup() */
+} /* ex_ondup() */
 
-static void ex_data_free(void *parent NOTUSED, void *_data, CRYPTO_EX_DATA *ad NOTUSED, int idx NOTUSED, long argl NOTUSED, void *argp NOTUSED) {
+static void ex_onfree(void *parent NOTUSED, void *_data, CRYPTO_EX_DATA *ad NOTUSED, int idx NOTUSED, long argl NOTUSED, void *argp NOTUSED) {
 	struct ex_data *data = _data;
 
 	if (!data || --data->refs > 0)
 		return;
 
-	if (data->state)
+	if (data->state) {
+		int i;
+
+		for (i = 0; i < (int)countof(data->arg); i++) {
+			auxL_unref(data->state->L, &data->arg[i]);
+		}
+
 		LIST_REMOVE(data, le);
+	}
 
 	free(data);
-} /* ex_data_free() */
+} /* ex_onfree() */
 
 static int ex_initonce(void) {
 	struct ex_type *type;
 
 	for (type = ex_type; type < endof(ex_type); type++) {
-		if (-1 == (type->index = CRYPTO_get_ex_new_index(type->class_index, 0, NULL, NULL, &ex_data_dup, &ex_data_free)))
+		if (-1 == (type->index = CRYPTO_get_ex_new_index(type->class_index, 0, NULL, NULL, &ex_ondup, &ex_onfree)))
 			return -1;
 	};
 
@@ -676,22 +730,27 @@ static int ex__gc(lua_State *L) {
 	return 0;
 } /* ex__gc() */
 
-static void ex_init(lua_State *L) {
+static void ex_newstate(lua_State *L) {
 	struct ex_state *state;
 	struct lua_State *thr;
 
 	state = prepudata(L, sizeof *state, NULL, &ex__gc);
 	LIST_INIT(&state->data);
 
+	/*
+	 * XXX: Don't reuse mainthread because if an error occurs in a
+	 * callback Lua might longjmp across the OpenSSL call stack.
+	 * Instead, we'll install our own panic handlers.
+	 */
 #if defined LUA_RIDX_MAINTHREAD
 	lua_rawgeti(L, LUA_REGISTRYINDEX, LUA_RIDX_MAINTHREAD);
-	state->mainthread = lua_tothread(L, -1);
+	state->L = lua_tothread(L, -1);
 	lua_pop(L, 1);
 #else
 	lua_pushvalue(L, -1);
 	thr = lua_newthread(L);
 	lua_settable(L, LUA_REGISTRYINDEX);
-	state->mainthread = thr;
+	state->L = thr;
 #endif
 
 	lua_pushcfunction(L, &ex__gc);
@@ -699,9 +758,9 @@ static void ex_init(lua_State *L) {
 	lua_settable(L, LUA_REGISTRYINDEX);
 
 	lua_pop(L, 1);
-} /* ex_init() */
+} /* ex_newstate() */
 
-static struct ex_state *ex_get(lua_State *L) {
+static struct ex_state *ex_getstate(lua_State *L) {
 	struct ex_state *state;
 
 	lua_pushcfunction(L, &ex__gc);
@@ -712,12 +771,12 @@ static struct ex_state *ex_get(lua_State *L) {
 	lua_pop(L, 1);
 
 	return state;
-} /* ex_get() */
+} /* ex_getstate() */
 
-static int ex_data_get(lua_State **L, int _type, void *obj) {
+static size_t ex_getdata(lua_State **L, int _type, void *obj) {
 	struct ex_type *type = &ex_type[_type];
 	struct ex_data *data;
-	int i;
+	size_t i;
 
 	if (!(data = type->get_ex_data(obj, type->index)))
 		return 0;
@@ -725,28 +784,31 @@ static int ex_data_get(lua_State **L, int _type, void *obj) {
 		return 0;
 
 	if (!*L)
-		*L = data->state->mainthread;
+		*L = data->state->L;
+
+	if (!lua_checkstack(*L, countof(data->arg)))
+		return 0;
 
-	for (i = 0; i < (int)countof(data->arg); i++) {
+	for (i = 0; i < countof(data->arg) && data->arg[i] != LUA_NOREF; i++) {
 		lua_rawgeti(*L, LUA_REGISTRYINDEX, data->arg[i]);
 	}
 
 	return i;
-} /* ex_data_get() */
+} /* ex_getdata() */
 
-static int ex_data_set(lua_State *L, int _type, void *obj, int n) {
+/* returns 0 on success, otherwise error (>0 == errno, -1 == OpenSSL error) */
+static int ex_setdata(lua_State *L, int _type, void *obj, size_t n) {
 	struct ex_type *type = &ex_type[_type];
 	struct ex_state *state;
 	struct ex_data *data;
-	int i, j;
+	size_t i, j;
 
 	if ((data = type->get_ex_data(obj, type->index)) && data->state) {
-		for (i = 0; i < (int)countof(data->arg); i++) {
-			luaL_unref(L, LUA_REGISTRYINDEX, data->arg[i]);
-			data->arg[i] = LUA_NOREF;
+		for (i = 0; i < countof(data->arg); i++) {
+			auxL_unref(L, &data->arg[i]);
 		}
 	} else {
-		state = ex_get(L);
+		state = ex_getstate(L);
 
 		if (!(data = malloc(sizeof *data)))
 			return errno;
@@ -756,20 +818,19 @@ static int ex_data_set(lua_State *L, int _type, void *obj, int n) {
 
 		data->state = state;
 		data->refs = 1;
-		for (i = 0; i < (int)countof(data->arg); i++)
+		for (i = 0; i < countof(data->arg); i++)
 			data->arg[i] = LUA_NOREF;
 		LIST_INSERT_HEAD(&state->data, data, le);
 	}
 
-	for (i = n, j = 0; i > 0 && j < (int)countof(data->arg); i--, j++) {
-		lua_pushvalue(L, -i);
-		data->arg[j] = luaL_ref(L, LUA_REGISTRYINDEX);
+	for (i = n, j = 0; i > 0 && j < countof(data->arg); i--, j++) {
+		auxL_ref(L, -(int)i, &data->arg[j]);
 	}
 
 	lua_pop(L, n);
 
 	return 0;
-} /* ex_data_set() */
+} /* ex_setdata() */
 
 static void initall(lua_State *L);
 
@@ -4800,13 +4861,44 @@ static int sx_setAlpnProtos(lua_State *L) {
 #endif
 
 #if HAVE_SSL_CTX_SET_ALPN_SELECT_CB
-static int sx_setAlpnSelect_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen, void *arg) {
+static SSL *ssl_push(lua_State *, SSL *);
+
+static int sx_setAlpnSelect_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen, void *_ctx) {
+	SSL_CTX *ctx = _ctx;
 	lua_State *L = NULL;
-	int n;
+	size_t n;
+	int top, status;
 
-	n = ex_data_get(&L, EX_SSL_CTX_ALPN_SELECT_CB, ssl);
+	if (0 == (n = ex_getdata(&L, EX_SSL_CTX_ALPN_SELECT_CB, ctx)))
+		return SSL_TLSEXT_ERR_ALERT_FATAL;
 
-	return 0;
+	top = lua_gettop(L) - n;
+
+	/* TODO: Install temporary panic handler to catch OOM errors */
+
+	/* pass the SSL object as first argument */
+	ssl_push(L, ssl);
+	pushprotos(L, in, inlen);
+
+	/* TODO: lua_rotate ssl and protocols table into position. */
+
+	if (LUA_OK != (status = lua_pcall(L, 2 + (n - 1), 1, 0)))
+		goto fatal;
+
+	/* TODO: check return value */
+	(void)out; (void)outlen;
+
+	lua_settop(L, top);
+
+	return SSL_TLSEXT_ERR_OK;
+fatal:
+	lua_settop(L, top);
+
+	return SSL_TLSEXT_ERR_ALERT_FATAL;
+noack:
+	lua_settop(L, top);
+
+	return SSL_TLSEXT_ERR_NOACK;
 } /* sx_setAlpnSelect_cb() */
 
 static int sx_setAlpnSelect(lua_State *L) {
@@ -4815,9 +4907,17 @@ static int sx_setAlpnSelect(lua_State *L) {
 	int error;
 
 	luaL_checktype(L, 2, LUA_TFUNCTION);
-	error = ex_data_set(L, EX_SSL_CTX_ALPN_SELECT_CB, ctx, 1);
+	if ((error = ex_setdata(L, EX_SSL_CTX_ALPN_SELECT_CB, ctx, 1))) {
+		if (error > 0) {
+			return luaL_error(L, "unable to set ALPN protocol selection callback: %s", xstrerror(error));
+		} else if (!ERR_peek_error()) {
+			return luaL_error(L, "unable to set ALPN protocol selection callback: Unknown internal error");
+		} else {
+			return throwssl(L, "ssl.context:setAlpnSelect");
+		}
+	}
 
-	SSL_CTX_set_alpn_select_cb(ctx, &sx_setAlpnSelect_cb, NULL);
+	SSL_CTX_set_alpn_select_cb(ctx, &sx_setAlpnSelect_cb, ctx);
 
 	lua_pushboolean(L, 1);
 
@@ -4936,6 +5036,15 @@ int luaopen__openssl_ssl_context(lua_State *L) {
  *
  * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
 
+static SSL *ssl_push(lua_State *L, SSL *ssl) {
+	SSL **ud = prepsimple(L, SSL_CLASS);
+
+	CRYPTO_add(&(ssl)->references, 1, CRYPTO_LOCK_SSL);
+	*ud = ssl;
+
+	return *ud;
+} /* ssl_push() */
+
 static int ssl_new(lua_State *L) {
 	lua_pushnil(L);
 
@@ -6168,7 +6277,7 @@ static void initall(lua_State *L) {
 
 	pthread_mutex_unlock(&mutex);
 
-	ex_init(L);
+	ex_newstate(L);
 
 	addclass(L, BIGNUM_CLASS, bn_methods, bn_metatable);
 	addclass(L, PKEY_CLASS, pk_methods, pk_metatable);
-- 
cgit v1.2.3-59-g8ed1b