// Copyright 1996 Michael E. Stillman.

#include "det.hpp"
#include "text_io.hpp"
#include "bin_io.hpp"

extern int comp_printlevel;

// Set up the enumeration of the p by p minors of M.
//   do_exterior: if true, each minor is stored in a matrix whose row and
//     column modules are the p-th exterior powers of M's row and column
//     modules; otherwise each nonzero minor is appended (by step()) as a
//     column of a one-row matrix.
//   strategy: DET_BAREISS selects fraction-free elimination on a scratch
//     p by p array D; any other value uses cofactor expansion.
// When do_trivial_case() holds, the computation is marked done at once and
// no scratch arrays are allocated (row_set, col_set, D stay NULL/0).
DetComputation::DetComputation(const Matrix &M, int p,
			       bool do_exterior,
			       int strategy)
  : R(M.get_ring()),
    M(M),
    done(false),
    p(p),
    do_exterior(do_exterior),
    strategy(strategy),
    row_set(NULL),
    col_set(NULL),
    this_row(0),
    this_col(0),
    D(0)
{
  bump_up(R);

  if (!do_exterior)
    {
      // One-row target; step() appends a column per nonzero minor.
      FreeModule *F = R->make_FreeModule(1);
      result = Matrix(F);
      // MES: do I need to bump down F??
    }
  else
    {
      FreeModule *row_module = M.rows()->exterior(p);
      FreeModule *col_module = M.cols()->exterior(p);
      // Degree shift of the result: M's degree shift raised to the p-th
      // power in the degree monoid.
      int *shift = R->degree_monoid()->make_new(M.degree_shift());
      R->degree_monoid()->power(shift, p, shift);
      result = Matrix(row_module, col_module, shift);
      R->degree_monoid()->remove(shift);
    }

  if (do_trivial_case())
    {
      done = true;
      return;
    }

  // Begin with the row and column subsets {0,...,p-1}, and a p by p
  // scratch array whose entries all start as zero.
  row_set = new int[p];
  col_set = new int[p];
  D = new ring_elem *[p];
  for (int k = 0; k < p; k++)
    {
      row_set[k] = k;
      col_set[k] = k;
      D[k] = new ring_elem[p];
      for (int j = 0; j < p; j++)
        D[k][j] = (Nterm *)0;
    }
}

// Undo the constructor's bump_up(R) and free the index arrays and the
// scratch array D.  Only the row arrays of D are freed, not the ring
// elements they held: bareiss_det() already removes or clears those as it
// runs.
DetComputation::~DetComputation()
{
  bump_down(const_cast<Ring *>(R));
  delete [] row_set;
  delete [] col_set;

  // D is only allocated for non-trivial p (see the constructor).
  if (D == 0)
    return;
  for (int row = 0; row < p; row++)
    delete [] D[row];
  delete [] D;
}

int DetComputation::step()
     // Compute one more determinant of size p.
     // increments I and/or J and updates 'dets', 'table'.
{
  if (done) return COMP_DONE;

  ring_elem r;

  if (strategy == DET_BAREISS)
    {
      get_minor(row_set,col_set,p,D);
      r = bareiss_det();
    }
  else
    r = calc_det(row_set, col_set, p);

  if (!R->is_zero(r))
    {
      if (do_exterior)
	{
	  vec v = result.rows()->term(this_row,r);
	  result.rows()->add_to(result[this_col], v);
	}
      else
	result.append(result.rows()->term(0,r));
    }
  R->remove(r);

  this_row++;
  if (!comb::increment(p, M.n_rows(), row_set))
    {
      // Now we increment column
      if (!comb::increment(p, M.n_cols(), col_set))
	{
	  done = true;
	  return COMP_DONE;
	}
      // Now set the row set back to initial value
      this_col++;
      this_row = 0;
      for (int i=0; i<p; i++) row_set[i]=i;
    }
  return COMP_COMPUTING;
}

bool DetComputation::do_trivial_case()
     // True when there is no p by p minor to enumerate: p < 0, p == 0, or
     // p larger than the number of rows or columns of M.  In every such
     // case 'result' is left exactly as the constructor built it.
     // NOTE(review): for p == 0 the empty minor is conventionally 1, but
     // nothing records that here -- confirm the intended output.
{
  // Short-circuiting keeps the original order: M's dimensions are only
  // consulted when p is positive.
  return p <= 0 || p > M.n_rows() || p > M.n_cols();
}

void DetComputation::clear()
     // Discard the minors collected so far in one-row mode by rebuilding
     // 'result' from its row module, as the constructor does.  Has no
     // effect in exterior-power mode.
{
  if (!do_exterior)
    result = Matrix(result.rows());
}

void DetComputation::set_next_minor(const int *rows, const int *cols)
     // Position the enumeration so that the next step() starts from the
     // row subset 'rows' and the column subset 'cols'.  A NULL pointer, or
     // a subset rejected by comb::valid_subset, is replaced by
     // {0,...,p-1}.  This does not clear 'done'.
     // Ignored in exterior-power mode, presumably because this_row and
     // this_col index into the exterior-power result and would no longer
     // agree with row_set/col_set.
{
  if (do_exterior) return;

  // The constructor allocates row_set/col_set only when p is a
  // non-trivial size.  Otherwise (e.g. p larger than M) they are NULL, and
  // writing p entries through them would dereference NULL.  There is no
  // minor to position at, so do nothing.
  if (row_set == NULL || col_set == NULL) return;

  int i;
  if (rows != NULL && comb::valid_subset(p, M.n_rows(), rows))
    for (i=0; i<p; i++) row_set[i] = rows[i];
  else
    for (i=0; i<p; i++) row_set[i] = i;

  if (cols != NULL && comb::valid_subset(p, M.n_cols(), cols))
    for (i=0; i<p; i++) col_set[i] = cols[i];
  else
    for (i=0; i<p; i++) col_set[i] = i;
}

int DetComputation::calc(int nsteps)
     // Run step() repeatedly.  Returns COMP_DONE when every minor has
     // been processed, COMP_DONE_STEPS after nsteps steps, or
     // COMP_INTERRUPTED when system_interrupted is set.  A non-positive
     // nsteps is decremented away from zero, so in that case only
     // completion or interruption ends the loop.
{
  while (true)
    {
      // 'status', not 'result': avoid shadowing the member matrix.
      int status = step();
      if (comp_printlevel >= 3)
	emit_wrapped(".");
      if (status == COMP_DONE)
	return COMP_DONE;
      nsteps--;
      if (nsteps == 0)
	return COMP_DONE_STEPS;
      system_spincursor();
      if (system_interrupted)
	return COMP_INTERRUPTED;
    }
}

void DetComputation::get_minor(int *r, int *c, int p, ring_elem **D)
{
  for (int i=0; i<p; i++)
    for (int j=0; j<p; j++)
      D[i][j] = M.elem(r[i],c[j]);
}

bool DetComputation::get_pivot(ring_elem **D, int r, ring_elem &pivot, int &pivot_col)
  // Scan row r of D over columns 0..r for the first nonzero entry.  If
  // one is found, set pivot_col to its column and pivot to the entry
  // itself (assigned directly, without R->copy) and return true.  Return
  // false, leaving pivot and pivot_col untouched, if all those entries
  // are zero.
{
  // MES: it would be worthwhile to find a good pivot.
  int c = 0;
  while (c <= r && R->is_zero(D[r][c]))
    c++;
  if (c > r)
    return false;
  pivot_col = c;
  pivot = D[r][c];
  return true;
}

ring_elem DetComputation::detmult(ring_elem f1, ring_elem g1,
				  ring_elem f2, ring_elem g2,
				  ring_elem d)
  // One fraction-free elimination update.  Returns (f1*g1 - f2*g2) / d,
  // or f1*g1 - f2*g2 undivided when d is zero.  bareiss_det() passes a
  // zero d on its first step, where there is no previous pivot.
  // g1 is removed here; f1, f2, g2 and d remain owned by the caller.
{
  ring_elem answer = R->mult(f1, g1);
  ring_elem cross = R->mult(f2, g2);
  R->subtract_to(answer, cross);  // NOTE(review): presumably consumes 'cross'

  if (R->is_zero(d))
    {
      R->remove(g1);
      return answer;
    }

  ring_elem quotient = R->divide(answer, d);
  R->remove(answer);
  R->remove(g1);
  return quotient;
}

void DetComputation::gauss(ring_elem **D, int i, int r, int pivot_col, ring_elem lastpivot)
  // Eliminate column pivot_col from row i (i != r) using pivot row r.
  // For each column c in 0..r other than pivot_col, the entry is replaced
  // by detmult(D[r][pivot_col], D[i][c], D[i][pivot_col], D[r][c],
  // lastpivot).  The results are packed into columns 0..r-1: columns
  // after pivot_col move left by one.
  // The old row-i entries and D[i][pivot_col] are removed, so D[i][r] is
  // left holding a stale value.  Row r is not modified.
{
  ring_elem f = D[i][pivot_col];
  ring_elem pivot = D[r][pivot_col];
  ring_elem *dst = D[i];
  ring_elem *pivot_row = D[r];

  int out = 0;
  for (int c = 0; c <= r; c++)
    {
      if (c == pivot_col)
	continue;
      dst[out] = detmult(pivot, dst[c], f, pivot_row[c], lastpivot);
      out++;
    }

  R->remove(f);
}

ring_elem DetComputation::bareiss_det()
  // Fraction-free (Bareiss) determinant of the p by p scratch array D, as
  // filled by get_minor().  Rows are eliminated from the bottom (r = p-1)
  // up.  At each step the first nonzero entry of row r is the pivot (see
  // get_pivot).  Every other row is updated by gauss(), dividing by the
  // previous step's pivot; on the first step that pivot is zero, so
  // detmult() skips the division.
  // Ownership: the D entries are consumed.  On return they are removed,
  // and the result is taken out of D[0][0], which is reset to zero.
  // The caller owns the returned element.
{
  // Computes the determinant of the p by p matrix D. (dense form).
  int sign = 1;
  int pivot_col;

  // 'pivot' holds the current step's pivot and 'lastpivot' the previous
  // one.  Both start as freshly allocated zeros.
  ring_elem pivot = R->from_int(0);
  ring_elem lastpivot = R->from_int(0);

  for (int r=p-1; r>=1; --r)
    {
      // Drop the pivot from two steps back.  From here until get_pivot
      // succeeds, 'pivot' and 'lastpivot' alias the same element.
      R->remove(lastpivot);
      lastpivot = pivot;
      if (!get_pivot(D, r, pivot, pivot_col)) // sets pivot_col and pivot
	{
	  // Row r is zero in columns 0..r, so the determinant is zero.
	  // Remove the rest of D.
	  for (int i=0; i<=r; i++)
	    for (int j=0; j<=r; j++)
	      R->remove(D[i][j]);
	  // 'pivot' aliases 'lastpivot' here; remove it only once.
	  R->remove(lastpivot);
	  return R->from_int(0);
	}
      for (int i=0; i<r; i++)
	gauss(D,i,r,pivot_col,lastpivot);

      // gauss() packs the columns after pivot_col leftward.  This amounts
      // to moving column pivot_col to position r, a cyclic shift with sign
      // (-1)^(r - pivot_col).  r + pivot_col has the same parity.
      if (((r + pivot_col) % 2) == 1)
	sign = -sign;  // MES: do I need to rethink this logic?

      // Row r is finished.  Free its non-pivot entries.  The pivot entry is
      // now owned by the local 'pivot', so just clear the slot.
      for (int c=0; c<=r; c++)
	if (c != pivot_col)
	  R->remove(D[r][c]);
	else
	  D[r][c] = (Nterm *)0;
    }

  // After the loop 'pivot' and 'lastpivot' are distinct elements (or the
  // two initial zeros when p == 1).
  R->remove(pivot);
  R->remove(lastpivot);
  // The remaining 1 by 1 entry is the determinant up to sign; take
  // ownership of it out of D.
  ring_elem result = D[0][0];
  D[0][0] = (Nterm*)0;

  if (sign < 0) R->negate_to(result);

  return result;
}
ring_elem DetComputation::calc_det(int *r, int *c, int p)
     // Determinant of the minor of M with rows r[0]..r[p-1] and columns
     // c[0]..c[p-1], by cofactor (Laplace) expansion along the last row
     // r[p-1], recursing on the (p-1) by (p-1) minors.  The array c is
     // permuted during the computation but restored to its original order
     // before returning.  Requires p >= 1, since c[p-1] is accessed.
     // The caller owns the returned element.
{
  if (p == 1) return M.elem(r[0],c[0]);

  ring_elem result = R->from_int(0);
  const int last = p - 1;

  // Visit the columns from last to first.  Swapping c[k] into slot 'last'
  // leaves c[0..last-1] holding the other columns in their original
  // order, so the recursive call sees exactly the complementary minor.
  // The cofactor sign (-1)^(last+k) is '+' at k == last and alternates.
  for (int k = last; k >= 0; k--)
    {
      swap(c[k], c[last]);
      ring_elem entry = M.elem(r[last], c[last]);
      if (R->is_zero(entry))
	{
	  R->remove(entry);
	  continue;
	}
      ring_elem sub_det = calc_det(r, c, last);
      ring_elem product = R->mult(entry, sub_det);
      R->remove(entry);
      R->remove(sub_det);
      if ((last - k) % 2 == 1)
	R->subtract_to(result, product);
      else
	R->add_to(result, product);
    }

  // The swaps above leave c as (c1, ..., c_{p-1}, c0) in terms of the
  // original entries.  Rotate right by one to put c0 back in front.
  int first = c[last];
  for (int k = last; k > 0; k--)
    c[k] = c[k-1];
  c[0] = first;

  return result;
}
