/*************************************************
* Number Theory Source File                      *
* (C) 1999-2002 The OpenCL Project               *
*************************************************/

#include <opencl/primes.h>
#include <opencl/numthry.h>

namespace OpenCL {

namespace {

// Return the number of trailing zero bits of n (the exponent of the largest
// power of two dividing n). Returns 0 when n is zero.
u32bit low_zero_bits(const BigInt& n)
   {
   if(n.is_zero()) return 0;

   const u32bit max_bits = n.bits();
   u32bit bits = 0;
   // Test the bound before calling get_bit() so the bit accessor is never
   // consulted at or beyond max_bits.
   while(bits < max_bits && n.get_bit(bits) == 0)
      bits++;
   return bits;
   }

// Number of Miller-Rabin rounds to run for a candidate of the given bit
// length. Smaller candidates get more rounds; anything longer than 1250 bits
// gets 2. NOTE(review): the origin of these counts is not recorded here.
u32bit miller_rabin_test_iterations(u32bit bits)
   {
   if(bits <=   50) return 35;
   if(bits <=  100) return 30;
   if(bits <=  150) return 25;
   if(bits <=  200) return 20;
   if(bits <=  300) return 15;
   if(bits <=  400) return 10;
   if(bits <=  500) return  8;
   if(bits <=  600) return  7;
   if(bits <=  700) return  5;
   if(bits <=  800) return  4;
   if(bits <= 1250) return  3;
   return 2;
   }

}

/*************************************************
* Square a BigInt                                *
*************************************************/
// Return a * a.
BigInt square(const BigInt& a)
   {
   BigInt product = a * a;
   return product;
   }

/*************************************************
* Calculate the GCD                              *
*************************************************/
// Greatest common divisor of a and b, computed with the binary (Stein) GCD
// algorithm on the absolute values. The result is always non-negative.
// gcd(a, 0) = |a|, gcd(0, b) = |b|, and gcd(0, 0) = 0.
BigInt gcd(const BigInt& a, const BigInt& b)
   {
   if(a.is_zero() || b.is_zero())
      {
      // Every integer divides 0, so the gcd is the magnitude of the other one.
      BigInt other = a.is_zero() ? b : a;
      other.set_sign(Positive);
      return other;
      }

   BigInt x = a, y = b;
   x.set_sign(Positive);
   y.set_sign(Positive);

   // Remove the largest power of two dividing both inputs in one step; it is
   // restored at the end.
   const u32bit x_shift = low_zero_bits(x);
   const u32bit y_shift = low_zero_bits(y);
   const u32bit shift = (x_shift < y_shift) ? x_shift : y_shift;
   x >>= shift;
   y >>= shift;

   // At least one of x, y is now odd. Each pass makes both odd, then replaces
   // the larger by half their (even) difference. y never becomes zero, and the
   // loop ends when x does, leaving the odd part of the gcd in y.
   while(x.is_nonzero())
      {
      x >>= low_zero_bits(x);
      y >>= low_zero_bits(y);
      if(x >= y) x = (x - y) >> 1;
      else       y = (y - x) >> 1;
      }

   y <<= shift;
   return y;
   }

/*************************************************
* Calculate the LCM                              *
*************************************************/
// Least common multiple of a and b. Returns 0 if either argument is zero.
// For nonzero arguments the sign of the result follows the sign of a * b,
// as in the original formula (a * b) / gcd(a, b).
BigInt lcm(const BigInt& a, const BigInt& b)
   {
   // The lcm with 0 is 0 by definition. Returning early also ensures we never
   // divide by gcd(a, b) in a case where it could be zero.
   if(a.is_zero() || b.is_zero())
      return BigInt::zero();

   // gcd(a, b) divides a exactly, so dividing first gives the same value as
   // (a * b) / gcd(a, b) without forming the full double-width product.
   return ((a / gcd(a, b)) * b);
   }

/*************************************************
* Exponentiation                                 *
*************************************************/
// Raise base to a machine-word exponent using right-to-left binary
// (square-and-multiply) exponentiation. power(base, 0) returns 1.
BigInt power(const BigInt& base, u32bit power)
   {
   BigInt result = BigInt::one();
   BigInt running = base;        // base^(2^i) while processing bit i
   u32bit remaining = power;

   while(true)
      {
      if((remaining & 1) != 0)
         result *= running;
      remaining >>= 1;
      if(remaining == 0)
         break;                  // skip the square that would never be used
      running = square(running);
      }
   return result;
   }

/*************************************************
* Modular Exponentiation                         *
*************************************************/
// Compute base^power mod `mod`. A reducer for `mod` is obtained from
// get_reducer() and released before returning. The ModularReducer* overload
// throws Invalid_Argument for a negative power.
BigInt power_mod(const BigInt& base, const BigInt& power, const BigInt& mod)
   {
   // Owns the reducer so it is deleted on every exit path, including when the
   // exponentiation below throws (the original leaked it in that case).
   struct reducer_owner
      {
      ModularReducer* ptr;
      ~reducer_owner() { delete ptr; }
      };
   reducer_owner owner = { get_reducer(mod) };

   return power_mod(base, power, owner.ptr);
   }

/*************************************************
* Modular Exponentiation                         *
*************************************************/
// Compute base^power using reducer->multiply() for every product, scanning
// the exponent bits from least to most significant. A zero exponent returns 1.
// The reducer is borrowed, not owned; presumably it must be non-null (it is
// dereferenced whenever power != 0).
// Throws Invalid_Argument if power is negative.
BigInt power_mod(const BigInt& base, const BigInt& power,
                 ModularReducer* reducer)
   {
   if(power.is_negative())
      throw Invalid_Argument("power_mod: power must be positive");

   BigInt x = BigInt::one(), a = base;
   const u32bit power_bits = power.bits();

   for(u32bit j = 0; j != power_bits; j++)
      {
      if(power.get_bit(j))
         x = reducer->multiply(a, x);
      // After the top bit, the next square of `a` would never be used, so
      // skip it (matching power(), which also omits the final squaring).
      if(j + 1 != power_bits)
         a = reducer->multiply(a, a);
      }
   return x;
   }

/*************************************************
* Find the Modular Inverse                       *
*************************************************/
// Return the inverse of n modulo mod, i.e. D in [0, mod) with n*D = 1 (mod mod),
// using the binary extended Euclidean algorithm.
// Returns 0 when no inverse exists (gcd(n, mod) != 1).
// Throws BigInt::DivideByZero if mod is zero and Invalid_Argument if either
// argument is negative.
BigInt inverse_mod(const BigInt& n, const BigInt& mod)
   {
   if(mod.is_zero()) throw BigInt::DivideByZero();
   if(mod.is_negative() || n.is_negative())
      throw Invalid_Argument("inverse_mod: arguments must be non-negative");
   // Both even => gcd is at least 2, so there is no inverse.
   if(n.is_even() && mod.is_even())
      return BigInt::zero();
   // 0 has no inverse (for mod == 1 the answer is 0 anyway). Without this
   // check, v stays 0 below and `u -= v` never makes progress: infinite loop.
   if(n.is_zero())
      return BigInt::zero();

   // Invariants maintained throughout:
   //    A*x + B*y == u   and   C*x + D*y == v,   where x = mod, y = n.
   BigInt x = mod, y = n, u = mod, v = n;
   BigInt A = BigInt::one(),  B = BigInt::zero(),
          C = BigInt::zero(), D = BigInt::one();

   while(u.is_nonzero())
      {
      // Halve u. If A, B cannot be halved exactly, add (y, -x) first; this
      // keeps A*x + B*y unchanged.
      while(u.is_even() && u.is_nonzero())
         {
         u >>= 1;
         if(A.is_odd() || B.is_odd()) { A += y; B -= x; }
         A >>= 1; B >>= 1;
         }

      // Same for v with its coefficients C, D.
      while(v.is_even() && v.is_nonzero())
         {
         v >>= 1;
         if(C.is_odd() || D.is_odd()) { C += y; D -= x; }
         C >>= 1; D >>= 1;
         }

      // Subtract the smaller from the larger, along with its coefficients.
      if(u >= v) { u -= v; A -= C; B -= D; }
      else       { v -= u; C -= A; D -= B; }
      }

   // v now holds gcd(mod, n), and C*mod + D*n == v.
   if(v != BigInt::one())
      return BigInt::zero();

   // D*n = 1 (mod mod); normalize D into [0, mod).
   while(D.is_negative()) D += mod;
   while(D >= mod) D -= mod;

   return D;
   }

/*************************************************
* Check for primality                            *
*************************************************/
bool is_prime(const BigInt& n)
   {
   if(n == BigInt::zero() || n == BigInt::one() || n.is_even())
      return false;
   for(u32bit j = 1; j != 200; j++)
      {
      if(n <= PRIMES[j])         return true;
      if(gcd(n, PRIMES[j]) != 1) return false;
      }
   BigInt m = n - BigInt::one();
   u32bit k = low_zero_bits(m);
   m >>= k;

   u32bit tests = miller_rabin_test_iterations(n.bits());

   ModularReducer* reducer = get_reducer(n);
   for(u32bit j = 0; j != tests && j != PRIME_TABLE_SIZE; j++)
      if(!strong_probable_prime(n, m, k, PRIMES[j], reducer))
         { delete reducer; return false; }
   delete reducer;
   return true;
   }

/*************************************************
* Miller-Rabin Strong Prime Test                 *
*************************************************/
bool strong_probable_prime(const BigInt& n, const BigInt& m, u32bit k,
                           const BigInt& a, ModularReducer* reducer)
   {
   if(a < 2 || a > n-BigInt::one())
      throw Invalid_Argument("Bad size for 'a' in Miller-Rabin test");
   BigInt b = power_mod(a, m, reducer), n_minus_1 = n - BigInt::one();
   if(b == BigInt::one() || b == n_minus_1) return true;
   for(u32bit j = 0; j != k; j++)
      {
      if(b == n_minus_1)     return true;
      if(b == BigInt::one()) return false;
      b = reducer->multiply(b, b);
      }
   return false;
   }

}
