/*************************************************
* BigInt Base Source File                        *
* (C) 1999-2002 The OpenCL Project               *
*************************************************/

#include <opencl/bigint.h>
#include <opencl/numthry.h>

namespace OpenCL {

/*************************************************
* Build a BigInt from a 64-bit unsigned value    *
*************************************************/
BigInt::BigInt(u64bit n)
   {
   // Two words are created when the value fits in 32 bits, four
   // words otherwise.
   const bool fits_in_one_word = (n <= 0xFFFFFFFF);
   reg.create(fits_in_one_word ? 2 : 4);

   // Word 0 holds the low 32 bits, word 1 the high 32 bits.
   reg[1] = (n >> 32) & 0xFFFFFFFF;
   reg[0] = (n & 0xFFFFFFFF);
   set_sign(Positive);
   }

/*************************************************
* Create a BigInt with a given sign and size     *
*************************************************/
BigInt::BigInt(Sign s, u32bit size)
   {
   // The sign is stored exactly as given; it is not routed through
   // set_sign().
   signedness = s;
   reg.create(size);
   }

/*************************************************
* Copy a BigInt, keeping only significant words  *
*************************************************/
BigInt::BigInt(const BigInt& b)
   {
   const u32bit used_words = b.sig_words();
   if(used_words == 0)
      {
      // The source is zero: start from a fresh two-word register.
      reg.create(2);
      set_sign(Positive);
      return;
      }

   // Copy only the words up to the highest nonzero one.
   reg.resize_and_copy(b.reg.ptr(), used_words);
   set_sign(b.sign());
   }

/*************************************************
* Construct a BigInt from an array of words      *
*************************************************/
BigInt::BigInt(const u32bit* n, u32bit n_size, Sign s)
   {
   // Take n_size words from n as the register contents.
   reg.resize_and_copy(n, n_size);

   // The sign must be set after the words are in place, because
   // set_sign() stores Positive whenever the value is zero.
   set_sign(s);
   }

/*************************************************
* Parse a BigInt from a decimal or hex string    *
*************************************************/
BigInt::BigInt(const std::string& str)
   {
   u32bit prefix_len = 0;

   // An optional leading '-' marks a negative value.
   const bool is_negative = (!str.empty() && str[0] == '-');
   if(is_negative)
      prefix_len = 1;

   // "0x" after the optional sign selects hexadecimal, but only when
   // at least one more character follows it.
   Encoding format = Decimal;
   if(str.length() > prefix_len + 2 &&
      str[prefix_len] == '0' && str[prefix_len + 1] == 'x')
      {
      format = Hexadecimal;
      prefix_len += 2;
      }

   // Decode everything after the sign / radix prefix.
   *this = decode((byte*)str.data() + prefix_len,
                  str.length() - prefix_len, format);
   set_sign(is_negative ? Negative : Positive);
   }

/*************************************************
* Construct a BigInt from an encoded byte vector *
*************************************************/
BigInt::BigInt(const SecureVector<byte>& input, Encoding encoding)
   {
   // Decode the whole vector using the requested encoding.
   *this = decode(input, encoding);

   // A value built this way is always given the Positive sign.
   set_sign(Positive);
   }

/*************************************************
* Construct a BigInt from an encoded byte array  *
*************************************************/
BigInt::BigInt(const byte input[], u32bit length, Encoding encoding)
   {
   // Decode length bytes of input using the requested encoding.
   *this = decode(input, length, encoding);

   // A value built this way is always given the Positive sign.
   set_sign(Positive);
   }

/*************************************************
* Construct a randomly generated BigInt          *
*************************************************/
BigInt::BigInt(NumberType type, u32bit bits)
   {
   // Each recognized type is handed to the matching generator, with
   // bits passed through unchanged.
   if(type == Random)
      *this = random_integer(bits);
   else if(type == Prime)
      *this = random_prime(bits);
   else if(type == SafePrime)
      *this = random_safe_prime(bits);
   else
      {
      // No generator matched: fall back to a well-defined zero rather
      // than leaving the sign flag unset.
      // NOTE(review): unreachable if NumberType has no enumerators
      // besides the three tested above -- confirm against the header.
      reg.create(2);
      set_sign(Positive);
      }
   }

/*************************************************
* Exchange the contents of two BigInts           *
*************************************************/
void BigInt::swap(BigInt& other)
   {
   // The two members are independent, so each is swapped on its own.
   std::swap(signedness, other.signedness);
   std::swap(reg, other.reg);
   }

/*************************************************
* Add the word n to the stored magnitude         *
*************************************************/
void BigInt::add(u32bit n)
   {
   if(n == 0)
      return;

   const u32bit old_low = reg[0];
   reg[0] += n;

   // n is nonzero, so the low word wrapped around exactly when it did
   // not increase.
   if(reg[0] > old_low)
      return;

   // Ripple the carry upward; it is absorbed by the first word that
   // does not wrap to zero.
   u32bit j = 1;
   while(j != size())
      {
      ++reg[j];
      if(reg[j] != 0)
         return;
      ++j;
      }

   // Every word wrapped: enlarge the register and store the carry at
   // index size()/2 (the old length, assuming grow_to() makes size()
   // report the doubled length).
   reg.grow_to(2*size());
   reg[size() / 2] = 1;
   }

/*************************************************
* Subtract the word n from the stored magnitude  *
*************************************************/
void BigInt::sub(u32bit n)
   {
   if(!n) return;

   const u32bit temp = reg[0];
   reg[0] -= n;

   // n is nonzero, so the low word got smaller exactly when no borrow
   // was needed.
   if(reg[0] < temp)
      return;

   // Propagate the borrow upward; the first nonzero word absorbs it.
   for(u32bit j = 1; j != size(); j++)
      if(reg[j]--) return;

   // No higher word could absorb the borrow, so every word above the
   // lowest was zero and n exceeded the stored magnitude (temp). The
   // result has magnitude n - temp and the opposite sign.
   reg.create(2);
   reg[0] = n - temp;

   // Flip only after the new magnitude is stored: flip_sign() forces
   // Positive while the value is still zero, which would silently
   // discard the sign change.
   flip_sign();
   }

/*************************************************
* Return word n, or zero if n is out of range    *
*************************************************/
u32bit BigInt::at(u32bit n) const
   {
   // Indices at or past size() read as zero instead of touching reg.
   return (n < size()) ? reg[n] : 0;
   }

/*************************************************
* Report whether bit n is set                    *
*************************************************/
bool BigInt::get_bit(u32bit n) const
   {
   const u32bit word_index = n / 32;
   const u32bit bit_offset = n % 32;

   // at() yields zero for words past the end, so such bits read as
   // clear.
   return ((at(word_index) >> bit_offset) & 1);
   }

/*************************************************
* Set bit n, growing the register if needed      *
*************************************************/
void BigInt::set_bit(u32bit n)
   {
   const u32bit word = n / 32;

   // Shift an unsigned one so that bit 31 is reached without shifting
   // into the sign bit of an int.
   const u32bit mask = static_cast<u32bit>(1) << (n % 32);

   // Make sure the word holding the bit exists before writing to it.
   if(word >= size()) reg.grow_to(word + 1);
   reg[word] |= mask;
   }

/*************************************************
* Clear bit n (no-op beyond the register)        *
*************************************************/
void BigInt::clear_bit(u32bit n)
   {
   const u32bit word = n / 32;

   // Shift an unsigned one so that bit 31 is reached without shifting
   // into the sign bit of an int.
   const u32bit mask = static_cast<u32bit>(1) << (n % 32);

   // Bits past the end of the register are not stored, so there is
   // nothing to clear for them.
   if(word < size())
      reg[word] &= ~mask;
   }

/*************************************************
* Count the words up to the highest nonzero one  *
*************************************************/
u32bit BigInt::sig_words() const
   {
   // Walk down from the top of the register until a nonzero word is
   // found; its index plus one is the answer.
   for(u32bit remaining = size(); remaining != 0; --remaining)
      if(reg[remaining-1] != 0)
         return remaining;

   // Every word was zero (or the register is empty).
   return 0;
   }

/*************************************************
* Number of bytes needed to hold bits() bits     *
*************************************************/
u32bit BigInt::bytes() const
   {
   const u32bit total_bits = bits();
   u32bit byte_count = total_bits / 8;

   // A partially filled byte still occupies a whole byte.
   if(total_bits % 8 != 0)
      ++byte_count;
   return byte_count;
   }

/*************************************************
* Count the bits up to the highest set bit       *
*************************************************/
u32bit BigInt::bits() const
   {
   const u32bit used_words = sig_words();
   if(used_words == 0)
      return 0;

   const u32bit top_index = used_words - 1;
   const u32bit top_word = at(top_index);

   // Scan the top word from bit 31 downward, dropping one from the
   // count for every leading zero bit.
   u32bit top_bits = 32;
   for(u32bit probe = 0x80000000; top_bits != 0; probe >>= 1)
      {
      if((top_word & probe) != 0)
         break;
      --top_bits;
      }

   // Every word below the top one contributes a full 32 bits.
   return (top_index * 32 + top_bits);
   }

/*************************************************
* Test whether every word of the value is zero   *
*************************************************/
bool BigInt::is_zero() const
   {
   // Skip over zero words from the bottom of the register.
   u32bit j = 0;
   while(j != size() && reg[j] == 0)
      ++j;

   // Zero exactly when the scan ran off the end of the register.
   return (j == size());
   }

/*************************************************
* Set the sign (zero is always Positive)         *
*************************************************/
void BigInt::set_sign(Sign s)
   {
   // A zero value is stored as Positive whatever sign was requested.
   signedness = is_zero() ? Positive : s;
   }

/*************************************************
* Flip the sign flag; zero is made Positive      *
*************************************************/
void BigInt::flip_sign()
   {
   // Zero has no negative form, so it is simply marked Positive.
   if(is_zero())
      {
      signedness = Positive;
      return;
      }

   set_sign(reverse_sign());
   }

/*************************************************
* Return the sign opposite to the current one    *
*************************************************/
Sign BigInt::reverse_sign() const
   {
   // Pure query: the stored sign is not modified.
   return (sign() == Positive) ? Negative : Positive;
   }

/*************************************************
* Unary minus: same words, opposite sign         *
*************************************************/
BigInt BigInt::operator-() const
   {
   // The word-array constructor applies the sign through set_sign(),
   // so negating zero still yields a Positive zero.
   const Sign flipped = reverse_sign();
   return BigInt(reg, reg.size(), flipped);
   }

/*************************************************
* Return a Positive-signed copy of this number   *
*************************************************/
BigInt BigInt::abs() const
   {
   // Work on a copy so that this object is left untouched.
   BigInt magnitude(*this);
   magnitude.set_sign(Positive);
   return magnitude;
   }

/*************************************************
* Write the value as bytes into a caller buffer  *
*************************************************/
void BigInt::binary_encode(byte output[]) const
   {
   // Exactly bytes() entries of output are written, front to back.
   const u32bit total = bytes();
   for(u32bit pos = 0; pos != total; ++pos)
      {
      // j counts bytes from the far end of the output: the last byte
      // written has j == 1 and is taken from reg[0].
      const u32bit j = total - pos;

      // NOTE(review): 4 - (j % 4) is 4 when j is a multiple of 4;
      // presumably get_byte() reduces that to a valid byte position
      // -- confirm against its definition.
      output[pos] = get_byte(4 - (j % 4), reg[(j - 1) / 4]);
      }
   }

/*************************************************
* Encode this number into a new byte vector      *
*************************************************/
SecureVector<byte> BigInt::binary_encode() const
   {
   // Size the vector to bytes(), then let the array overload fill it.
   const u32bit length = bytes();
   SecureVector<byte> encoded(length);
   binary_encode(encoded);
   return encoded;
   }

/*************************************************
* Set this number to the value in buf            *
*************************************************/
void BigInt::binary_decode(const SecureVector<byte>& buf)
   {
   binary_decode(buf, buf.size());
   }

/*************************************************
* Set this number to the value in buf            *
*************************************************/
void BigInt::binary_decode(const byte buf[], u32bit length)
   {
   reg.create(length / 4 + 1);
   for(u32bit j = 0; j != length / 4; j++)
      {
      u32bit top = length - 4*j;
      reg[j] = make_u32bit(buf[top - 4], buf[top - 3],
                           buf[top - 2], buf[top - 1]);
      }
   for(u32bit j = 0; j != length % 4; j++)
      reg[length / 4] = (reg[length / 4] << 8) + buf[j];
   }

}
