/*************************************************
* RSA Source File                                *
* (C) 1999-2002 The OpenCL Project               *
*************************************************/

#include <opencl/rsa.h>
#include <opencl/numthry.h>

namespace OpenCL {

/*************************************************
* RSA Encryption Function                        *
* Reads the len bytes at in into temp, runs the  *
* value through public_op() and returns the      *
* encode()d result. public_op() throws           *
* Invalid_Argument if the value is out of range. *
*************************************************/
SecureVector<byte> RSA_PublicKey::encrypt(const byte in[], u32bit len) const
   {
   temp.binary_decode(in, len);
   const BigInt result = public_op(temp);
   return encode(result);
   }

/*************************************************
* RSA Verification Function                      *
* Decodes in/len into temp, applies public_op()  *
* and hands back the encode()d integer.          *
*************************************************/
SecureVector<byte> RSA_PublicKey::verify(const byte in[], u32bit len) const
   {
   temp.binary_decode(in, len);
   const BigInt recovered = public_op(temp);
   return encode(recovered);
   }

/*************************************************
* RSA Public Operation                           *
* Returns power_mod(i, e, reducer_n) for i in    *
* [0, n); throws Invalid_Argument otherwise.     *
* reducer_n is built here if it is still null.   *
*************************************************/
BigInt RSA_PublicKey::public_op(const BigInt& i) const
   {
   const bool out_of_range = (i >= n || i.is_negative());
   if(out_of_range)
      throw Invalid_Argument("RSA::public_op: i >= n || i < 0");

   /* Nothing in this function guards the lazy setup against concurrent
      callers -- NOTE(review): confirm whether callers need to. */
   if(!reducer_n)
      reducer_n = get_reducer(n);

   return power_mod(i, e, reducer_n);
   }

/*************************************************
* RSA_PublicKey Destructor                       *
* Frees the reducer held in reducer_n.           *
*************************************************/
RSA_PublicKey::~RSA_PublicKey()
   {
   /* public_op() only fills reducer_n in when it finds it equal to 0 */
   delete reducer_n;
   }

/*************************************************
* Check RSA Parameters                           *
* Throws Invalid_Argument naming the first check *
* that fails; returns normally if all pass.      *
*************************************************/
void RSA_PrivateKey::check_params() const
   {
   const char* failure = 0;

   /* The tests run top to bottom and stop at the first one that trips */
   if(n <= 0 || p <= 0 || q <= 0 || d <= 0)
      failure = "RSA: Must have positive parameters";
   else if(e < 3)
      failure = "RSA: Encryption exponent too small";
   else if(e % 2 == 0)
      failure = "RSA: Encryption exponent is even";
   else if(!is_prime(p) || !is_prime(q))
      failure = "RSA: p and/or q is not prime";
   else if(p * q != n)
      failure = "RSA: p * q != n";
   else if((e * d) % lcm(p - 1, q - 1) != 1)
      failure = "RSA: Invalid exponent pair";

   if(failure)
      throw Invalid_Argument(failure);
   }

/*************************************************
* RSA Precomputation                             *
*************************************************/
void RSA_PrivateKey::precompute()
   {
   d1 = d % (p - 1);
   d2 = d % (q - 1);
   c = inverse_mod(q, p);
   reducer_p = get_reducer(p);
   reducer_q = get_reducer(q);
   }

/*************************************************
* RSA Decryption Operation                       *
* Decodes in/len into temp, applies private_op() *
* and returns the encode()d result.              *
*************************************************/
SecureVector<byte> RSA_PrivateKey::decrypt(const byte in[], u32bit len) const
   {
   temp.binary_decode(in, len);
   const BigInt plain = private_op(temp);
   return encode(plain);
   }

/*************************************************
* RSA Signature Operation                        *
* Loads in/len into temp and returns encode() of *
* private_op(temp).                              *
*************************************************/
SecureVector<byte> RSA_PrivateKey::sign(const byte in[], u32bit len) const
   {
   temp.binary_decode(in, len);
   const BigInt signature = private_op(temp);
   return encode(signature);
   }

/*************************************************
* RSA Private Key Operation                      *
* Exponentiates i separately mod p and mod q and *
* recombines the halves as j2 + h * q. Throws    *
* Invalid_Argument unless 0 <= i < n.            *
*************************************************/
BigInt RSA_PrivateKey::private_op(const BigInt& i) const
   {
   if(i < 0 || i >= n)
      throw Invalid_Argument("RSA::private_op: i >= n || i < 0");

   /* One exponentiation per prime, each through its own reducer */
   const BigInt reduced_p = reducer_p->reduce(i);
   j1 = power_mod(reduced_p, d1, reducer_p);

   const BigInt reduced_q = reducer_q->reduce(i);
   j2 = power_mod(reduced_q, d2, reducer_q);

   /* c is inverse_mod(q, p) from precompute(); h ties the two halves together */
   h = reducer_p->reduce(c * (j1 - j2));
   return (h * q + j2);
   }

/*************************************************
* Create a RSA Private Key                       *
* Takes the primes and public exponent as given. *
* A zero d_exponent or modulus is derived from   *
* them instead. check_params() is not called.    *
*************************************************/
RSA_PrivateKey::RSA_PrivateKey(const BigInt& prime1, const BigInt& prime2,
                               u32bit exponent, const BigInt& d_exponent,
                               const BigInt& modulus)
   {
   p = prime1;
   q = prime2;
   e = exponent;

   /* Zero means "not supplied": compute the private exponent ourselves */
   if(d_exponent != 0)
      d = d_exponent;
   else
      d = inverse_mod(e, lcm(p - 1, q - 1));

   /* Same convention for the modulus */
   if(modulus != 0)
      n = modulus;
   else
      n = p * q;

   precompute();
   }

/*************************************************
* Create a RSA Private Key                       *
* Draws p and q from random_prime(), retrying    *
* until gcd(e, p - 1) and gcd(e, q - 1) are 1.   *
* Throws Invalid_Argument if exp is < 3 or even. *
*************************************************/
RSA_PrivateKey::RSA_PrivateKey(u32bit bits, u32bit exp)
   {
   /* Validate the caller's exponent itself: the member e has not been
      assigned yet at this point, so testing e would not look at exp. */
   if(exp < 3)
      throw Invalid_Argument("RSA: Encryption exponent too small");
   if(exp % 2 == 0)
      throw Invalid_Argument("RSA: Encryption exponent is even");
   e = exp;

   do
      p = random_prime((bits + 1) / 2);
   while(gcd(e, p - 1) != 1);

   /* The second request is sized from what p actually used: bits - p.bits() */
   do
      q = random_prime(bits - p.bits());
   while(gcd(e, q - 1) != 1);

   n = p * q;
   d = inverse_mod(e, lcm(p - 1, q - 1));
   precompute();
   }

/*************************************************
* RSA_PrivateKey Destructor                      *
* Frees the reducers set up by precompute().     *
*************************************************/
RSA_PrivateKey::~RSA_PrivateKey()
   {
   delete reducer_p;
   delete reducer_q;
   }

}
