aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/openssl.c62
1 files changed, 42 insertions, 20 deletions
diff --git a/src/openssl.c b/src/openssl.c
index 112a4c8..fa7dd79 100644
--- a/src/openssl.c
+++ b/src/openssl.c
@@ -2563,26 +2563,39 @@ static BIGNUM *(checkbig)(lua_State *L, int index, _Bool *lvalue) {
} /* checkbig() */
-static void bn_prepops(lua_State *L, BIGNUM **r, BIGNUM **a, BIGNUM **b, _Bool commute) {
+/* prepare number at top of stack for unary operation, and push result object onto stack */
+static void bn_prepuop(lua_State *L, BIGNUM **r, BIGNUM **a, _Bool commute) {
_Bool lvalue = 1;
- lua_settop(L, 2); /* a, b */
+ *a = checkbig(L, -1, &lvalue);
- *a = checkbig(L, 1, &lvalue);
+ if (!lvalue && commute) {
+ lua_pushvalue(L, -1);
+ } else {
+ bn_push(L);
+ }
+
+ *r = *(BIGNUM **)lua_touserdata(L, -1);
+} /* bn_prepuop() */
- if (!lvalue && commute)
- lua_pushvalue(L, 1);
- *b = checkbig(L, 2, &lvalue);
+/* prepare numbers at top of stack for binary operation, and push result object onto stack */
+static void bn_prepbop(lua_State *L, BIGNUM **r, BIGNUM **a, BIGNUM **b, _Bool commute) {
+ _Bool a_lvalue, b_lvalue;
- if (!lvalue && commute && lua_gettop(L) < 3)
- lua_pushvalue(L, 2);
+ *a = checkbig(L, -2, &a_lvalue);
+ *b = checkbig(L, -1, &b_lvalue);
- if (lua_gettop(L) < 3)
+ if (commute && !a_lvalue) {
+ lua_pushvalue(L, -2);
+ } else if (commute && !b_lvalue) {
+ lua_pushvalue(L, -1);
+ } else {
bn_push(L);
+ }
- *r = *(BIGNUM **)lua_touserdata(L, 3);
-} /* bn_prepops() */
+ *r = *(BIGNUM **)lua_touserdata(L, -1);
+} /* bn_prepbop() */
static int ctx__gc(lua_State *L) {
@@ -2639,7 +2652,8 @@ static int bn_toBinary(lua_State *L) {
static int bn__add(lua_State *L) {
BIGNUM *r, *a, *b;
- bn_prepops(L, &r, &a, &b, 1);
+ lua_settop(L, 2);
+ bn_prepbop(L, &r, &a, &b, 1);
if (!BN_add(r, a, b))
return auxL_error(L, auxL_EOPENSSL, "bignum:__add");
@@ -2651,7 +2665,8 @@ static int bn__add(lua_State *L) {
static int bn__sub(lua_State *L) {
BIGNUM *r, *a, *b;
- bn_prepops(L, &r, &a, &b, 0);
+ lua_settop(L, 2);
+ bn_prepbop(L, &r, &a, &b, 0);
if (!BN_sub(r, a, b))
return auxL_error(L, auxL_EOPENSSL, "bignum:__sub");
@@ -2663,7 +2678,8 @@ static int bn__sub(lua_State *L) {
static int bn__mul(lua_State *L) {
BIGNUM *r, *a, *b;
- bn_prepops(L, &r, &a, &b, 1);
+ lua_settop(L, 2);
+ bn_prepbop(L, &r, &a, &b, 1);
if (!BN_mul(r, a, b, getctx(L)))
return auxL_error(L, auxL_EOPENSSL, "bignum:__mul");
@@ -2675,7 +2691,8 @@ static int bn__mul(lua_State *L) {
static int bn_sqr(lua_State *L) {
BIGNUM *r, *a;
- bn_prepops(L, &r, &a, NULL, 1);
+ lua_settop(L, 1);
+ bn_prepuop(L, &r, &a, 1);
if (!BN_sqr(r, a, getctx(L)))
return auxL_error(L, auxL_EOPENSSL, "bignum:sqr");
@@ -2687,7 +2704,8 @@ static int bn_sqr(lua_State *L) {
static int bn__idiv(lua_State *L) {
BIGNUM *dv, *a, *b;
- bn_prepops(L, &dv, &a, &b, 0);
+ lua_settop(L, 2);
+ bn_prepbop(L, &dv, &a, &b, 0);
if (!BN_div(dv, NULL, a, b, getctx(L)))
return auxL_error(L, auxL_EOPENSSL, "bignum:__idiv");
@@ -2699,7 +2717,8 @@ static int bn__idiv(lua_State *L) {
static int bn__mod(lua_State *L) {
BIGNUM *r, *a, *b;
- bn_prepops(L, &r, &a, &b, 0);
+ lua_settop(L, 2);
+ bn_prepbop(L, &r, &a, &b, 0);
if (!BN_mod(r, a, b, getctx(L)))
return auxL_error(L, auxL_EOPENSSL, "bignum:__mod");
@@ -2717,7 +2736,8 @@ static int bn__mod(lua_State *L) {
static int bn_nnmod(lua_State *L) {
BIGNUM *r, *a, *b;
- bn_prepops(L, &r, &a, &b, 0);
+ lua_settop(L, 2);
+ bn_prepbop(L, &r, &a, &b, 0);
if (!BN_nnmod(r, a, b, getctx(L)))
return auxL_error(L, auxL_EOPENSSL, "bignum:nnmod");
@@ -2729,7 +2749,8 @@ static int bn_nnmod(lua_State *L) {
static int bn__pow(lua_State *L) {
BIGNUM *r, *a, *b;
- bn_prepops(L, &r, &a, &b, 0);
+ lua_settop(L, 2);
+ bn_prepbop(L, &r, &a, &b, 0);
if (!BN_exp(r, a, b, getctx(L)))
return auxL_error(L, auxL_EOPENSSL, "bignum:__pow");
@@ -2741,7 +2762,8 @@ static int bn__pow(lua_State *L) {
static int bn_gcd(lua_State *L) {
BIGNUM *r, *a, *b;
- bn_prepops(L, &r, &a, &b, 1);
+ lua_settop(L, 2);
+ bn_prepbop(L, &r, &a, &b, 1);
if (!BN_gcd(r, a, b, getctx(L)))
return auxL_error(L, auxL_EOPENSSL, "bignum:gcd");