www.usr.com/support/gpl/USR9107_release.1.4.tar.gz
[bcm963xx.git] / userapps / opensource / sshd / rsa.c
index b00b143..7248bed 100755 (executable)
@@ -29,7 +29,7 @@
  * Applied Cryptography detail the general algorithm. */
 
 #include "includes.h"
-#include "util.h"
+#include "dbutil.h"
 #include "bignum.h"
 #include "rsa.h"
 #include "buffer.h"
@@ -38,8 +38,9 @@
 
 #ifdef DROPBEAR_RSA 
 
-static mp_int * rsa_pad_em(rsa_key * key,
-               const unsigned char * data, unsigned int len);
+static void rsa_pad_em(rsa_key * key,
+               const unsigned char * data, unsigned int len,
+               mp_int * rsa_em);
 
 /* Load a public rsa key from a buffer, initialising the values.
  * The key will have the same format as buf_put_rsa_key.
@@ -47,7 +48,7 @@ static mp_int * rsa_pad_em(rsa_key * key,
  * Returns DROPBEAR_SUCCESS or DROPBEAR_FAILURE */
 int buf_get_rsa_pub_key(buffer* buf, rsa_key *key) {
 
-       TRACE(("enter buf_get_rsa_pub_key"));
+       TRACE(("enter buf_get_rsa_pub_key"))
        assert(key != NULL);
        key->e = m_malloc(sizeof(mp_int));
        key->n = m_malloc(sizeof(mp_int));
@@ -60,10 +61,16 @@ int buf_get_rsa_pub_key(buffer* buf, rsa_key *key) {
 
        if (buf_getmpint(buf, key->e) == DROPBEAR_FAILURE
         || buf_getmpint(buf, key->n) == DROPBEAR_FAILURE) {
-               TRACE(("leave buf_get_rsa_pub_key: failure"));
+               TRACE(("leave buf_get_rsa_pub_key: failure"))
                return DROPBEAR_FAILURE;
        }
-       TRACE(("leave buf_get_rsa_pub_key: success"));
+
+       if (mp_count_bits(key->n) < MIN_RSA_KEYLEN) {
+               dropbear_log(LOG_WARNING, "rsa key too short");
+               return DROPBEAR_FAILURE;
+       }
+
+       TRACE(("leave buf_get_rsa_pub_key: success"))
        return DROPBEAR_SUCCESS;
 
 }
@@ -75,17 +82,17 @@ int buf_get_rsa_priv_key(buffer* buf, rsa_key *key) {
 
        assert(key != NULL);
 
-       TRACE(("enter buf_get_rsa_priv_key"));
+       TRACE(("enter buf_get_rsa_priv_key"))
 
        if (buf_get_rsa_pub_key(buf, key) == DROPBEAR_FAILURE) {
-               TRACE(("leave buf_get_rsa_priv_key: pub: ret == DROPBEAR_FAILURE"));
+               TRACE(("leave buf_get_rsa_priv_key: pub: ret == DROPBEAR_FAILURE"))
                return DROPBEAR_FAILURE;
        }
 
        key->d = m_malloc(sizeof(mp_int));
        m_mp_init(key->d);
        if (buf_getmpint(buf, key->d) == DROPBEAR_FAILURE) {
-               TRACE(("leave buf_get_rsa_priv_key: d: ret == DROPBEAR_FAILURE"));
+               TRACE(("leave buf_get_rsa_priv_key: d: ret == DROPBEAR_FAILURE"))
                return DROPBEAR_FAILURE;
        }
 
@@ -99,17 +106,17 @@ int buf_get_rsa_priv_key(buffer* buf, rsa_key *key) {
                m_mp_init_multi(key->p, key->q, NULL);
 
                if (buf_getmpint(buf, key->p) == DROPBEAR_FAILURE) {
-                       TRACE(("leave buf_get_rsa_priv_key: p: ret == DROPBEAR_FAILURE"));
+                       TRACE(("leave buf_get_rsa_priv_key: p: ret == DROPBEAR_FAILURE"))
                        return DROPBEAR_FAILURE;
                }
 
                if (buf_getmpint(buf, key->q) == DROPBEAR_FAILURE) {
-                       TRACE(("leave buf_get_rsa_priv_key: q: ret == DROPBEAR_FAILURE"));
+                       TRACE(("leave buf_get_rsa_priv_key: q: ret == DROPBEAR_FAILURE"))
                        return DROPBEAR_FAILURE;
                }
        }
 
-       TRACE(("leave buf_get_rsa_priv_key"));
+       TRACE(("leave buf_get_rsa_priv_key"))
        return DROPBEAR_SUCCESS;
 }
        
@@ -117,10 +124,10 @@ int buf_get_rsa_priv_key(buffer* buf, rsa_key *key) {
 /* Clear and free the memory used by a public or private key */
 void rsa_key_free(rsa_key *key) {
 
-       TRACE(("enter rsa_key_free"));
+       TRACE(("enter rsa_key_free"))
 
        if (key == NULL) {
-               TRACE(("leave rsa_key_free: key == NULL"));
+               TRACE(("leave rsa_key_free: key == NULL"))
                return;
        }
        if (key->d) {
@@ -144,7 +151,7 @@ void rsa_key_free(rsa_key *key) {
                m_free(key->q);
        }
        m_free(key);
-       TRACE(("leave rsa_key_free"));
+       TRACE(("leave rsa_key_free"))
 }
 
 /* Put the public rsa key into the buffer in the required format:
@@ -155,21 +162,21 @@ void rsa_key_free(rsa_key *key) {
  */
 void buf_put_rsa_pub_key(buffer* buf, rsa_key *key) {
 
-       TRACE(("enter buf_put_rsa_pub_key"));
+       TRACE(("enter buf_put_rsa_pub_key"))
        assert(key != NULL);
 
        buf_putstring(buf, SSH_SIGNKEY_RSA, SSH_SIGNKEY_RSA_LEN);
        buf_putmpint(buf, key->e);
        buf_putmpint(buf, key->n);
 
-       TRACE(("leave buf_put_rsa_pub_key"));
+       TRACE(("leave buf_put_rsa_pub_key"))
 
 }
 
 /* Same as buf_put_rsa_pub_key, but with the private "x" key appended */
 void buf_put_rsa_priv_key(buffer* buf, rsa_key *key) {
 
-       TRACE(("enter buf_put_rsa_priv_key"));
+       TRACE(("enter buf_put_rsa_priv_key"))
 
        assert(key != NULL);
        buf_put_rsa_pub_key(buf, key);
@@ -184,7 +191,7 @@ void buf_put_rsa_priv_key(buffer* buf, rsa_key *key) {
        }
 
 
-       TRACE(("leave buf_put_rsa_priv_key"));
+       TRACE(("leave buf_put_rsa_priv_key"))
 
 }
 
@@ -195,43 +202,55 @@ int buf_rsa_verify(buffer * buf, rsa_key *key, const unsigned char* data,
                unsigned int len) {
 
        unsigned int slen;
-       mp_int rsa_s, rsa_mdash;
-       mp_int *rsa_em = NULL;
+       DEF_MP_INT(rsa_s);
+       DEF_MP_INT(rsa_mdash);
+       DEF_MP_INT(rsa_em);
        int ret = DROPBEAR_FAILURE;
 
+       TRACE(("enter buf_rsa_verify"))
+
        assert(key != NULL);
 
-       m_mp_init_multi(&rsa_mdash, &rsa_s, NULL);
+       m_mp_init_multi(&rsa_mdash, &rsa_s, &rsa_em, NULL);
 
        slen = buf_getint(buf);
        if (slen != (unsigned int)mp_unsigned_bin_size(key->n)) {
-               TRACE(("bad size"));
+               TRACE(("bad size"))
                goto out;
        }
 
        if (mp_read_unsigned_bin(&rsa_s, buf_getptr(buf, buf->len - buf->pos),
                                buf->len - buf->pos) != MP_OKAY) {
+               TRACE(("failed reading rsa_s"))
+               goto out;
+       }
+
+       /* check that s <= n-1 */
+       if (mp_cmp(&rsa_s, key->n) != MP_LT) {
+               TRACE(("s > n-1"))
                goto out;
        }
 
        /* create the magic PKCS padded value */
-       rsa_em = rsa_pad_em(key, data, len);
+       rsa_pad_em(key, data, len, &rsa_em);
 
        if (mp_exptmod(&rsa_s, key->e, key->n, &rsa_mdash) != MP_OKAY) {
+               TRACE(("failed exptmod rsa_s"))
                goto out;
        }
 
-       if (mp_cmp(rsa_em, &rsa_mdash) == MP_EQ) {
+       if (mp_cmp(&rsa_em, &rsa_mdash) == MP_EQ) {
                /* signature is valid */
+               TRACE(("success!"))
                ret = DROPBEAR_SUCCESS;
        }
 
 out:
-       mp_clear_multi(rsa_em, &rsa_mdash, &rsa_s, NULL);
-       m_free(rsa_em);
+       mp_clear_multi(&rsa_mdash, &rsa_s, &rsa_em, NULL);
+       TRACE(("leave buf_rsa_verify: ret %d", ret))
        return ret;
-
 }
+
 #endif /* DROPBEAR_SIGNKEY_VERIFY */
 
 /* Sign the data presented with key, writing the signature contents
@@ -241,22 +260,56 @@ void buf_put_rsa_sign(buffer* buf, rsa_key *key, const unsigned char* data,
 
        unsigned int nsize, ssize;
        unsigned int i;
-       mp_int rsa_s;
-       mp_int *rsa_em;
+       DEF_MP_INT(rsa_s);
+       DEF_MP_INT(rsa_tmp1);
+       DEF_MP_INT(rsa_tmp2);
+       DEF_MP_INT(rsa_tmp3);
+       unsigned char *tmpbuf;
        
-       TRACE(("enter buf_put_rsa_sign"));
+       TRACE(("enter buf_put_rsa_sign"))
        assert(key != NULL);
 
-       rsa_em = rsa_pad_em(key, data, len);
+       m_mp_init_multi(&rsa_s, &rsa_tmp1, &rsa_tmp2, &rsa_tmp3, NULL);
+
+       rsa_pad_em(key, data, len, &rsa_tmp1);
 
        /* the actual signing of the padded data */
-       m_mp_init(&rsa_s);
+
+#ifdef RSA_BLINDING
+
+       /* With blinding, s = (r^(-1))((em)*r^e)^d mod n */
+
+       /* generate the r blinding value */
+       /* rsa_tmp2 is r */
+       gen_random_mpint(key->n, &rsa_tmp2);
+
+       /* rsa_tmp1 is em */
+       /* em' = em * r^e mod n */
+
+       mp_exptmod(&rsa_tmp2, key->e, key->n, &rsa_s); /* rsa_s used as a temp var*/
+       mp_invmod(&rsa_tmp2, key->n, &rsa_tmp3);
+       mp_mulmod(&rsa_tmp1, &rsa_s, key->n, &rsa_tmp2);
+
+       /* rsa_tmp2 is em' */
+       /* s' = (em')^d mod n */
+       mp_exptmod(&rsa_tmp2, key->d, key->n, &rsa_tmp1);
+
+       /* rsa_tmp1 is s' */
+       /* rsa_tmp3 is r^(-1) mod n */
+       /* s = (s')r^(-1) mod n */
+       mp_mulmod(&rsa_tmp1, &rsa_tmp3, key->n, &rsa_s);
+
+#else
+
        /* s = em^d mod n */
-       if (mp_exptmod(rsa_em, key->d, key->n, &rsa_s) != MP_OKAY) {
+       /* rsa_tmp1 is em */
+       if (mp_exptmod(&rsa_tmp1, key->d, key->n, &rsa_s) != MP_OKAY) {
                dropbear_exit("rsa error");
        }
-       mp_clear(rsa_em);
-       m_free(rsa_em);
+
+#endif /* RSA_BLINDING */
+
+       mp_clear_multi(&rsa_tmp1, &rsa_tmp2, &rsa_tmp3, NULL);
        
        /* create the signature to return */
        buf_putstring(buf, SSH_SIGNKEY_RSA, SSH_SIGNKEY_RSA_LEN);
@@ -279,11 +332,11 @@ void buf_put_rsa_sign(buffer* buf, rsa_key *key, const unsigned char* data,
        mp_clear(&rsa_s);
 
 #if defined(DEBUG_RSA) && defined(DEBUG_TRACE)
-       printhex(buf->data, buf->len);
+       printhex("RSA sig", buf->data, buf->len);
 #endif
        
 
-       TRACE(("leave buf_put_rsa_sign"));
+       TRACE(("leave buf_put_rsa_sign"))
 }
 
 /* Creates the message value as expected by PKCS, see rfc2437 etc */
@@ -295,19 +348,22 @@ void buf_put_rsa_sign(buffer* buf, rsa_key *key, const unsigned char* data,
  *
  * prefix is the ASN1 designator prefix,
  * hex 30 21 30 09 06 05 2B 0E 03 02 1A 05 00 04 14
+ *
+ * rsa_em must be a pointer to an initialised mp_int.
  */
-static mp_int * rsa_pad_em(rsa_key * key,
-               const unsigned char * data, unsigned int len) {
+static void rsa_pad_em(rsa_key * key,
+               const unsigned char * data, unsigned int len, 
+               mp_int * rsa_em) {
 
        /* ASN1 designator (including the 0x00 preceding) */
-       const char rsa_asn1_magic[] = 
+       const unsigned char rsa_asn1_magic[] = 
                {0x00, 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 
                 0x0e, 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14};
-#define RSA_ASN1_MAGIC_LEN 16
-       buffer * rsa_EM;
+       const unsigned int RSA_ASN1_MAGIC_LEN = 16;
+
+       buffer * rsa_EM = NULL;
        hash_state hs;
        unsigned int nsize;
-       mp_int * rsa_em;
        
        assert(key != NULL);
        assert(data != NULL);
@@ -335,16 +391,9 @@ static mp_int * rsa_pad_em(rsa_key * key,
 
        /* Create the mp_int from the encoded bytes */
        buf_setpos(rsa_EM, 0);
-       rsa_em = (mp_int*)m_malloc(sizeof(mp_int));
-       m_mp_init(rsa_em);
-       if (mp_read_unsigned_bin(rsa_em, buf_getptr(rsa_EM, rsa_EM->size),
-                               rsa_EM->size) != MP_OKAY) {
-               dropbear_exit("rsa error");
-       }
+       bytes_to_mp(rsa_em, buf_getptr(rsa_EM, rsa_EM->size),
+                       rsa_EM->size);
        buf_free(rsa_EM);
-
-       return rsa_em;
-
 }
 
 #endif /* DROPBEAR_RSA */