/*
Validation routines
This file is in the public domain
*/

#include <iostream>
#include <fstream>
#include <cctype>
#include <string>
#include <vector>

#include <opencl/filters.h>
#include <opencl/encoder.h>
#include <opencl/pipe.h>
#include <opencl/rng.h>
using namespace OpenCL_types;

// Set to 1 (for example with -DDEBUG=1 on the compiler command line) to
// print each algorithm name as it is tested and to report caught exceptions
// even for vectors that are expected to fail. Defaults to 0.
#ifndef DEBUG
   #define DEBUG 0
#endif

/*
* Return a pseudo-random value in the range [1, max].
*
* Four successive outputs of OpenCL::Global_RNG::random() are packed into a
* word, 8 bits apart, and the word is reduced modulo max; max must therefore
* be nonzero.
*/
u32bit random_word(u32bit max)
   {
   u32bit word = 0;
   for(int chunk = 0; chunk < 4; ++chunk)
      {
      word <<= 8;
      word |= OpenCL::Global_RNG::random();
      }

   const u32bit zero_based = word % max;
   return zero_based + 1;
   }

// Forward declarations. failed_test is defined below; the other helpers are
// defined later in this file or in another translation unit.

// Builds the filter for an algorithm name, given the key and IV strings of a
// test vector. Returns 0 when the name is not recognised; failed_test treats
// that as fatal unless the section was marked <EXTENSION>.
OpenCL::Filter* lookup(const std::string&, const std::string&,
                       const std::string&);

// Runs one test vector (algorithm, input, expected, key, IV, is_extension,
// exp_pass) and returns true if it failed.
bool failed_test(const std::string&, const std::string&,
                 const std::string&, const std::string&,
                 const std::string&, bool, bool);

// Splits a test-vector line into fields; do_validation_tests uses fields
// [0] through [3]. NOTE(review): presumably a delimiter-separated split -
// confirm against the definition.
std::vector<std::string> parse(const std::string&);
// Normalises a line in place before it is examined. NOTE(review): presumably
// trims whitespace and/or comments - confirm against the definition.
void strip(std::string&);
// Decodes a hex string into a newly allocated buffer that the caller must
// release with delete[] (as failed_test does). Presumably in.size()/2 bytes
// long - confirm.
byte* decode_hex(const std::string&);

/*
* Run every test vector in the file named by filename and return how many
* of them had an outcome different from should_pass.
*
* Each line is passed through strip(); blank lines are skipped. A line that
* ends in a backslash has the backslash dropped and the next line read
* appended to it (an empty next line contributes nothing). A line of the form
* [NAME] starts a new section: NAME is the algorithm for the vectors that
* follow. If NAME contains "<EXTENSION>", that marker is removed and the
* algorithm is flagged as optional, so failed_test skips it when lookup()
* does not know it. Any other line is split by parse(), and its first four
* fields are handed to failed_test as input, expected output, key and IV.
*
* should_pass - true if every vector is expected to succeed, false if every
*               vector is expected to fail
*
* Exits the process if the file cannot be opened or a stream error occurs.
*/
u32bit do_validation_tests(const std::string& filename, bool should_pass)
   {
   std::ifstream test_data(filename.c_str());

   if(!test_data)
       {
       std::cout << "Couldn't open test file " << filename << std::endl;
       std::exit(1);
       }

   u32bit errors = 0, alg_count = 0;
   std::string algorithm;
   bool is_extension = false;

   while(!test_data.eof())
      {
      if(test_data.bad() || test_data.fail())
         {
         std::cout << "File I/O error." << std::endl;
         std::exit(1);
         }

      std::string line;
      std::getline(test_data, line);

      strip(line);
      if(line.empty()) continue;

      // Do line continuation. Test for emptiness first: a line consisting of
      // just "\" is empty once the backslash is dropped, and
      // line[line.size()-1] on an empty string is out of bounds.
      while(!line.empty() && line[line.size()-1] == '\\' && !test_data.eof())
         {
         line.erase(line.size()-1);
         std::string nextline;
         std::getline(test_data, nextline);
         strip(nextline);
         if(nextline.empty()) continue;
         line += nextline;
         }

      // A continuation that only picked up empty lines can leave nothing.
      if(line.empty()) continue;

      if(line[0] == '[' && line[line.size() - 1] == ']')
         {
         const std::string ext_mark = "<EXTENSION>";
         algorithm = line.substr(1, line.size() - 2);
         is_extension = false;

         const std::string::size_type mark_pos = algorithm.find(ext_mark);
         if(mark_pos != std::string::npos)
            {
            is_extension = true;
            algorithm.erase(mark_pos, ext_mark.length());
            }
         alg_count = 0;
         continue;
         }

      std::vector<std::string> substr = parse(line);

      // failed_test needs four fields; treat missing trailing fields as
      // empty instead of indexing past the end of the vector.
      if(substr.size() < 4)
         substr.resize(4);

      alg_count++;

      const bool failed = failed_test(algorithm, substr[0], substr[1],
                                      substr[2], substr[3], is_extension,
                                      should_pass);

      if(failed && should_pass)
         {
         std::cout << "ERROR: \"" << algorithm << "\" failed test #"
                   << alg_count << std::endl;
         errors++;
         }

      if(!failed && !should_pass)
         {
         std::cout << "ERROR: \"" << algorithm << "\" passed test #"
                   << alg_count << " (unexpected pass)" << std::endl;
         errors++;
         }
      }
   return errors;
   }

/*
* Run one test vector through the named algorithm and report whether it
* failed: the algorithm threw, or its hex-encoded output differed from
* expected.
*
* algo         - algorithm name, resolved with lookup()
* in           - input message as a hex string (must have even length)
* expected     - expected output, compared with the HexEncoder output
* key, iv      - passed through to lookup()
* is_extension - if set, an algorithm unknown to lookup() is skipped rather
*                than being a fatal error
* exp_pass     - whether this vector is expected to succeed; controls which
*                results are printed
*
* Returns true when the vector failed. For an unknown extension algorithm it
* returns !exp_pass, which do_validation_tests treats as matching its
* expectation. Exits the process on an unknown non-extension algorithm or an
* odd-length input string.
*/
bool failed_test(const std::string& algo,     const std::string& in,
                 const std::string& expected, const std::string& key,
                 const std::string& iv, bool is_extension, bool exp_pass)
   {
#if DEBUG
   std::cout << "Testing: " << algo;
   if(!exp_pass)
      std::cout << " (expecting failure)";
   std::cout << std::endl;
#endif

   OpenCL::Pipe pipe;

   try
      {
      OpenCL::Filter* test = lookup(algo, key, iv);
      if(test == 0 && is_extension) return !exp_pass;
      if(test == 0)
         {
         std::cout << "ERROR: \"" + algo + "\" is not a known algorithm name."
                   << std::endl;
         std::exit(1);
         }

      pipe.append(test);
      pipe.append(new OpenCL::HexEncoder);

      pipe.start_msg();

      // Reject malformed input before decoding it, so decode_hex never sees
      // an odd-length string.
      if(in.size() % 2 == 1)
         {
         std::cout << "Can't have an odd sized hex string!" << std::endl;
         std::exit(1);
         }

      byte* data_ptr = decode_hex(in);

      // pipe.write/end_msg may throw: release the decoded buffer on that
      // path as well, then rethrow so the handlers below report it.
      try
         {
         // Feed the input in randomly sized pieces; this can help catch
         // errors with buffering, etc.
         byte* data = data_ptr;
         u32bit len = in.size() / 2;
         while(len)
            {
            u32bit how_much = random_word(len);
            pipe.write(data, how_much);
            data += how_much;
            len -= how_much;
            }
         pipe.end_msg();
         }
      catch(...)
         {
         delete[] data_ptr;
         throw;
         }
      delete[] data_ptr;
      }
   catch(OpenCL::Exception& e)
      {
      if(exp_pass || DEBUG)
         std::cout << "Exception caught: " << e.what() << std::endl;
      return true;
      }
   catch(std::exception& e)
      {
      if(exp_pass || DEBUG)
         std::cout << "Standard library exception caught: "
                   << e.what() << std::endl;
      return true;
      }
   catch(...)
      {
      if(exp_pass || DEBUG)
         std::cout << "Unknown exception caught." << std::endl;
      return true;
      }

   const std::string output = pipe.read_all_as_string();
   const bool matched = (output == expected);

   // Only results that contradict exp_pass are printed here;
   // do_validation_tests prints the corresponding ERROR line.
   if(matched && !exp_pass)
      {
      std::cout << "FAILED: " << expected << " == " << std::endl
                << "        " << output << std::endl;
      }
   else if(!matched && exp_pass)
      {
      std::cout << "FAILED: " << expected << " != " << std::endl
                << "        " << output << std::endl;
      }

   return !matched;
   }
