/* This file is in the public domain */
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <opencl/encoder.h>
#include <opencl/pubkey.h>
#include <opencl/pk_types.h>
#include <opencl/rng.h>

#include <opencl/randpool.h>

#include <opencl/rsa.h>

#include <opencl/sha1.h>
using namespace OpenCL;

class Fixed_Output_RNG : public RandomNumberGenerator
   {
   public:
      byte random()
         {
         if(position < output.size())
            return output[position++]; 
         else
            return 0;
         }
      void clear() throw() {}
      void add_entropy(const byte[], u32bit) throw() {}
      void add_entropy(EntropySource&, bool) {}
      Fixed_Output_RNG(const SecureVector<byte>& x) :
         RandomNumberGenerator("Fixed_Output_RNG")
         {
         output = x;
         position = 0;
         }
   private:
      SecureVector<byte> output;
      u32bit position;
   };

/*
* Helpers that are declared here but not defined in this part of the file.
* parse() is applied to each test-vector line and strip() to each raw input
* line by do_pk_validation_tests() below.
* NOTE(review): string_to_bigint() is neither defined nor called in the
* visible code (to_bigint() is used instead) -- confirm it is still needed.
*/
std::vector<std::string> parse(const std::string&);
void strip(std::string&);
BigInt string_to_bigint(const std::string&);

/*
* Per-scheme validators, defined further down. Each takes the algorithm name
* and the fields of one test vector and returns the number of failures.
*/
u32bit validate_rsa_enc(const std::string&, const std::vector<std::string>&);
u32bit validate_rsa_sig(const std::string&, const std::vector<std::string>&);

/*
* Run every public-key test vector found in the file named by `filename`.
*
* The file is read line by line. Each line is passed through strip(); lines
* that are empty afterwards are skipped. A line ending in a backslash is
* joined with the following line(s). A line of the form "[NAME]" selects the
* algorithm used for the vectors that follow and restarts the per-algorithm
* test counter. Every other line is split with parse() and handed to
* validate_rsa_enc() (algorithm name contains "RSAES_") or validate_rsa_sig()
* (algorithm name contains "RSASSA_").
*
* Returns the total number of failures reported by the validators. The
* process is terminated with exit status 1 if the file cannot be opened or a
* stream error is detected while reading.
*/
u32bit do_pk_validation_tests(const std::string& filename)
   {
   std::ifstream test_data(filename.c_str());

   if(!test_data)
      {
      std::cout << "Couldn't open test file " << filename << std::endl;
      std::exit(1);
      }

   u32bit errors = 0, alg_count = 0;
   std::string algorithm;

   while(!test_data.eof())
      {
      if(test_data.bad() || test_data.fail())
         {
         std::cout << "File I/O error." << std::endl;
         std::exit(1);
         }

      std::string line;
      std::getline(test_data, line);

      strip(line);
      if(line.empty()) continue;

      // Do line continuation. The emptiness test matters: a line holding
      // nothing but a backslash becomes empty once the backslash is removed,
      // and indexing line[size()-1] on an empty string is out of bounds.
      while(!line.empty() && line[line.size()-1] == '\\' &&
            !test_data.eof())
         {
         line.erase(line.size()-1);
         std::string nextline;
         std::getline(test_data, nextline);
         strip(nextline);
         line += nextline; // appending an empty continuation is harmless
         }

      // The continuation handling may have left nothing behind
      if(line.empty()) continue;

      // Section header: remember the algorithm and restart the numbering
      if(line[0] == '[' && line[line.size() - 1] == ']')
         {
         algorithm = line.substr(1, line.size() - 2);
         alg_count = 0;
         continue;
         }

      std::vector<std::string> substr = parse(line);

      // Exactly one validator handles a vector, so its result can never be
      // overwritten by a second one
      u32bit new_errors = 0;
      if(algorithm.find("RSAES_") != std::string::npos)
         new_errors = validate_rsa_enc(algorithm, substr);
      else if(algorithm.find("RSASSA_") != std::string::npos)
         new_errors = validate_rsa_sig(algorithm, substr);
      alg_count++;
      errors += new_errors;

      if(new_errors)
         std::cout << "ERROR: \"" << algorithm << "\" failed test #"
                   << alg_count << std::endl;
      }
   // reset the global rng, because the validate_xxx will replace it
   set_global_rng(new Randpool);

   return errors;
   }

/*
* Feed the string `in` through a Pipe containing a HexDecoder as a single
* message and return everything the pipe produced.
*/
SecureVector<byte> string_to_bytes(const std::string& in)
   {
   OpenCL::Pipe hex_pipe(new OpenCL::HexDecoder);

   hex_pipe.start_msg();
   hex_pipe.write(in);
   hex_pipe.end_msg();

   SecureVector<byte> decoded = hex_pipe.read_all();
   return decoded;
   }

/*
* Build a BigInt from the characters of `hex_string` by handing them to
* decode() with the Hexadecimal encoding selector.
*/
BigInt to_bigint(const std::string& hex_string)
   {
   // decode() takes a byte pointer; reinterpret the string's character data
   const byte* raw = reinterpret_cast<const byte*>(hex_string.data());
   return decode(raw, hex_string.length(), Hexadecimal);
   }

/*
* Check one RSA encryption test vector.
*
* algo: scheme name; "RSAES_EME1_SHA1" and "RSAES_RAW" are handled, anything
*       else prints a warning and counts as zero failures.
* str:  the fields of the vector, read as hex strings in this order:
*       [0] value whose lowest word is passed as the key's third constructor
*           argument (named `exponent` here), [1] and [2] the first two
*           RSA_PrivateKey constructor arguments, [3] message, [4] bytes the
*           fixed-output RNG will return, [5] expected ciphertext.
*
* The global RNG is replaced by a Fixed_Output_RNG so that the encryption is
* reproducible. The message is encrypted and compared with the expected
* output; only if that matches is the ciphertext decrypted and compared with
* the original message.
*
* Returns 0 on success and 1 on failure (including a vector with fewer than
* six fields). The encryptor and decryptor are released on every path,
* including when an exception passes through.
*/
u32bit validate_rsa_enc(const std::string& algo,
                        const std::vector<std::string>& str)
   {
   const bool use_eme1 = (algo == "RSAES_EME1_SHA1");
   const bool use_raw = (algo == "RSAES_RAW");

   // Reject unknown schemes before touching the vector or building a key
   if(!use_eme1 && !use_raw)
      {
      std::cout << "WARNING: pk.cpp doesn't know about " << algo << std::endl;
      return 0;
      }

   // Fields 0..5 are indexed below; never read past the end of the vector
   if(str.size() < 6)
      {
      std::cout << "FAILED (input): " << algo << " needs 6 fields, got "
                << str.size() << std::endl;
      return 1;
      }

   BigInt exponent = to_bigint(str[0]);

   RSA_PrivateKey key(to_bigint(str[1]), to_bigint(str[2]),
                      exponent.at(0));

   PK_Encryptor* e = 0;
   PK_Decryptor* d = 0;
   u32bit failures = 0;

   try
      {
      if(use_eme1)
         {
         e = new RSA_EME1_SHA1_Encryptor(key);
         d = new RSA_EME1_SHA1_Decryptor(key);
         }
      else
         {
         e = new RSA_Raw_Encryptor(key);
         d = new RSA_Raw_Decryptor(key);
         }

      SecureVector<byte> message = string_to_bytes(str[3]);
      set_global_rng(new Fixed_Output_RNG(string_to_bytes(str[4])));

      SecureVector<byte> expected = string_to_bytes(str[5]);

      SecureVector<byte> out = e->encrypt(message, message.size());
      if(out != expected)
         {
         std::cout << "FAILED (e): " << algo << std::endl;
         failures = 1;
         }
      else
         {
         SecureVector<byte> decrypted = d->decrypt(out, out.size());
         if(decrypted != message)
            {
            std::cout << "FAILED (d): " << algo << std::endl;
            failures = 1;
            }
         }
      }
   catch(...)
      {
      // Do not leak the operation objects when something throws
      delete e;
      delete d;
      throw;
      }

   delete e;
   delete d;
   return failures;
   }

/*
* Check one RSA signature test vector.
*
* algo: scheme name; "RSASSA_EMSA2_SHA1" and "RSASSA_RAW_SHA1" are handled,
*       anything else prints a warning and counts as zero failures.
* str:  the fields of the vector, read as hex strings in this order:
*       [0] value whose lowest word is passed as the key's third constructor
*           argument (named `exponent` here), [1] and [2] the first two
*           RSA_PrivateKey constructor arguments, [3] message, [4] bytes the
*           fixed-output RNG will return, [5] expected signature.
*
* The global RNG is replaced by a Fixed_Output_RNG so that signing is
* reproducible. The message is signed and the signature compared with the
* expected one (the produced signature is dumped in hex on mismatch); only
* if that matches is the signature passed to the verifier.
*
* Returns 0 on success and 1 on failure (including a vector with fewer than
* six fields). The signer and verifier are released on every path,
* including when an exception passes through.
*/
u32bit validate_rsa_sig(const std::string& algo,
                        const std::vector<std::string>& str)
   {
   const bool use_emsa2 = (algo == "RSASSA_EMSA2_SHA1");
   const bool use_raw = (algo == "RSASSA_RAW_SHA1");

   // Reject unknown schemes before touching the vector or building a key
   if(!use_emsa2 && !use_raw)
      {
      std::cout << "WARNING: pk.cpp doesn't know about " << algo << std::endl;
      return 0;
      }

   // Fields 0..5 are indexed below; never read past the end of the vector
   if(str.size() < 6)
      {
      std::cout << "FAILED (input): " << algo << " needs 6 fields, got "
                << str.size() << std::endl;
      return 1;
      }

   BigInt exponent = to_bigint(str[0]);

   RSA_PrivateKey key(to_bigint(str[1]), to_bigint(str[2]),
                      exponent.at(0));

   PK_Verifier* v = 0;
   PK_Signer* s = 0;
   u32bit failures = 0;

   try
      {
      if(use_emsa2)
         {
         v = new RSA_EMSA2_SHA1_Verifier(key);
         s = new RSA_EMSA2_SHA1_Signer(key);
         }
      else
         {
         v = new RSA_Raw_SHA1_Verifier(key);
         s = new RSA_Raw_SHA1_Signer(key);
         }

      SecureVector<byte> message = string_to_bytes(str[3]);
      set_global_rng(new Fixed_Output_RNG(string_to_bytes(str[4])));

      SecureVector<byte> expected = string_to_bytes(str[5]);

      s->update(message, message.size());
      SecureVector<byte> out = s->signature();

      if(out != expected)
         {
         // Dump what was actually produced to help diagnose the mismatch
         std::printf("out=");
         for(u32bit j = 0; j != out.size(); j++)
            std::printf("%02X ", static_cast<unsigned int>(out[j]));
         std::printf("\n");

         std::cout << "FAILED (s): " << algo << std::endl;
         failures = 1;
         }
      else
         {
         v->update(message, message.size());
         if(!v->valid_signature(out, out.size()))
            {
            std::cout << "FAILED (v): " << algo << std::endl;
            failures = 1;
            }
         }
      }
   catch(...)
      {
      // Do not leak the operation objects when something throws
      delete v;
      delete s;
      throw;
      }

   delete v;
   delete s;
   return failures;
   }
