/****************************************************************************
 *
 * DFT++:  density functional package developed by
 *         the research group of Prof. Tomas Arias, MIT.
 *
 * Principal author: Sohrab Ismail-Beigi
 *
 * Modifications for MPI version: Kenneth P Esler,
 *                                Sohrab Ismail-Beigi, and
 *                                Tairan Wang.
 *
 * Modifications for LSD version: Jason A Cline
 *
 * Modifications for lattice/Pulay forces: Gabor Csanyi and
 *                                         Sohrab Ismail-Beigi
 *
 * Copyright (C) 1996-1998 The Massachusetts Institute of Technology (MIT).
 *
 ****************************************************************************/

/* $Id: Basis.c,v 1.2 1999/12/03 20:57:33 tairan Exp $ */

#include "header.h"
#include <stdio.h>
#include <math.h>

/* 
 * Implementing the member functions of Basis
 */

/*
 * First stage Basis initialization.
 * Set up crystal structure parameters.
 */
void
Basis::init(real Ecut, int Nx_spec, int Ny_spec, int Nz_spec)
{
  /* Parameters:
   *   Ecut    -- kinetic-energy cutoff in Hartrees; the basis holds the
   *              G-vectors with 0.5*|G|^2 <= Ecut.
   *   Nx_spec, Ny_spec, Nz_spec
   *           -- requested FFT box sizes.  0 means "choose automatically";
   *              each nonzero value overrides the automatic choice for its
   *              own direction only.
   * Sets unit_cell_volume, invR, G, GGT, G{x,y,z}{min,max}, Nx, Ny, Nz and
   * NxNyNz from latvec.  Resets nbasis to 0 and the G-vector/index arrays
   * to NULL.  Dies if a requested FFT size is too small for Ecut. */
  int i, j;
  matrix3 invGGT,box;
  vector3 e,f;
  int ibox[3],fftbox[3],spec[3];

  dft_log("\n----- setup_basis() -- Part I ---\n");

  // Calculate unit cell volume
  unit_cell_volume = fabs(det3(latvec));

  // Calculate invR
  invR = inv3(latvec);

  // Calculate recip. lattice vectors and dot products
  G = (2.0*M_PI)* invR;
  GGT = G*(~G);
  dft_log("latvec =\n");
  latvec.print(dft_global_log,"%10lg ");
  dft_log("unit cell volume = %lg\n\n",unit_cell_volume);
  dft_log("G =\n");
  G.print(dft_global_log, "%10lg ");
  dft_log("GGT =\n");
  GGT.print(dft_global_log, "%10lg ");

  /* We want to know for which vectors v on the constant-energy surface
   * Ecut = 0.5*(v,GGT*v) the x, y, or z component/projection of v is maximal.
   * This is easy:  the gradient at that point v must lie in the x, y, or z
   * direction!  I.e. if e is the unit direction we are interested in,
   * GGT*v = mu*e, or v = mu*inv(GGT)*e, where mu is chosen to be
   * mu = sqrt(2*Ecut/(e,inv(GGT)*e)) to ensure v is on the const. energy
   * surface.  We solve this for e being the x,y,z unit vectors and put the
   * resulting vectors into the rows of box. */
  invGGT = inv3(GGT);
  for (i=0; i < 3; i++)
    {
      for (j=0; j < 3; j++)
        e.v[j] = 0.0;
      e.v[i] = 1.0;
      f = invGGT*e;
      f = sqrt(2.0*Ecut/(e*f))*f;
      /* Row i of box is the extremal vector for direction i.  The log below
       * prints box[i] as that vector.  Only the diagonal box.m[i][i] is used
       * to size the box, and it is the same whether f is stored as a row or
       * as a column. */
      for (j=0; j < 3; j++)
        box.m[i][j] = f.v[j];
    }

  dft_log("\nEnergy cutoff Ecut = %lg Hartrees\n",Ecut);
  dft_log("On the surface Ecut = 0.5*|G|^2, the vector extremizing\n");
  dft_log("(G,e_x) = "); box[0].print(dft_global_log,"%lg ");
  dft_log("(G,e_y) = "); box[1].print(dft_global_log,"%lg ");
  dft_log("(G,e_z) = "); box[2].print(dft_global_log,"%lg ");

  /* Truncate the values on the diagonal of box to integers.  This gives
   * the size of a box along x/y/z that contains all G-vectors
   * of energy < Ecut. */
  for (i=0; i < 3; i++)
    ibox[i] = (int)fabs(box.m[i][i]);
  Gxmax = ibox[0]; Gxmin = -ibox[0];
  Gymax = ibox[1]; Gymin = -ibox[1];
  Gzmax = ibox[2]; Gzmin = -ibox[2];

  dft_log("\nSize of box containing G-vectors = ");
  dft_log("[-%d,%d] by [-%d,%d] by [-%d,%d]\n",
            ibox[0],ibox[0],ibox[1],ibox[1],ibox[2],ibox[2]);
  dft_log("Size of box containing density = ");
  dft_log("[-%d,%d] by [-%d,%d] by [-%d,%d]\n\n",
            2*ibox[0],2*ibox[0],2*ibox[1],2*ibox[1],2*ibox[2],2*ibox[2]);

  /* Find the minimal FFT box size that factors into the primes (2,3,5,7).
   * The minimum FFT box size is 2*2*ibox+1, because we square the
   * wave-functions, so the FFT box G-vectors must at least range over
   * -2*ibox to 2*ibox inclusive.
   * However, various other routines require the FFT box size to be an
   * even integer, so we start the fftbox-size at 4*ibox+2.
   * The loop tries to factorize the fftbox-size into (2,3,5,7).  If that
   * isn't possible, it increases the fftbox-size by 2 and tries again. */
  for (i=0; i < 3; i++)
    {
      int b,n2,n3,n5,n7,done_factoring;

      fftbox[i] = 4*ibox[i]+2  -2;
      /* increase fftbox[i] by 2 and try to factor it into (2,3,5,7) */
      do
        {
          fftbox[i] += 2;
          b = fftbox[i];
          n2 = n3 = n5 = n7 = done_factoring = 0;
          while (!done_factoring)
            {
              if (b%2==0) { n2++; b /= 2; continue; }
              if (b%3==0) { n3++; b /= 3; continue; }
              if (b%5==0) { n5++; b /= 5; continue; }
              if (b%7==0) { n7++; b /= 7; continue; }
              done_factoring = 1;
            }
        }
      while (b != 1); /*  b==1 means fftbox[i] is (2,3,5,7) factorizable */
      dft_log("fftbox[%d] = %d =(2^%d)*(3^%d)*(5^%d)*(7^%d)\n",
                i,fftbox[i],n2,n3,n5,n7);
    }

  /* If we're given FFT box sizes to use, then use them.  A zero entry keeps
   * the automatically chosen size for that direction.  Previously a zero
   * entry was copied in and then rejected as "too small".
   * NOTE(review): requested sizes are only checked against the 4*ibox+2
   * lower bound below, not for evenness or (2,3,5,7) factorizability. */
  spec[0] = Nx_spec;
  spec[1] = Ny_spec;
  spec[2] = Nz_spec;
  if (Nx_spec != 0 || Ny_spec != 0 || Nz_spec !=0)
    {
      for (i=0; i < 3; i++)
        if (spec[i] != 0)
          fftbox[i] = spec[i];
      dft_log("==============================================\n");
      dft_log("Overiding with specified sizes: %d by %d by %d\n",
                fftbox[0],fftbox[1],fftbox[2]);
      dft_log("==============================================\n");
    }

  for (i=0; i < 3; i++)
    if (fftbox[i] < 4*ibox[i]+2)
      die("\nsetup_basis():  fftbox[%d] is too small.\nIt should be AT LEAST %d = 4*%d+2\n\n"
          , i, 4*ibox[i]+2,ibox[i] );
  Nx = fftbox[0];
  Ny = fftbox[1];
  Nz = fftbox[2];
  NxNyNz = Nx * Ny * Nz;

  dft_log("Using fftbox = %d by %d by %d\n",
            fftbox[0],fftbox[1],fftbox[2]);


  // nullify all pointers for now.
  nbasis = 0;
  Gx = Gy = Gz = NULL;
  index = NULL;
}


/*
 * Second part of the initialization:
 * setting up the g-vectors and index-matrix.
 */
void
Basis::init2(const Basis &b, const vector3 &kvec, 
		  real Ecut, real &G2max)
{
  // Copy over the basic stuff.
  latvec           = b.latvec;
  invR             = b.invR;
  G                = b.G;
  GGT              = b.GGT;
  unit_cell_volume = b.unit_cell_volume;
  Nx               = b.Nx;
  Ny               = b.Ny;
  Nz               = b.Nz;
  NxNyNz           = b.NxNyNz;
  Gxmin            = b.Gxmin;
  Gxmax            = b.Gxmax;
  Gymin            = b.Gymin;
  Gymax            = b.Gymax;
  Gzmin            = b.Gzmin;
  Gzmax            = b.Gzmax;


  if (Ecut <= 0.0) {
    // don't calculate mappings.
    nbasis = 0;
    Gx = Gy = Gz = NULL;
    index = NULL;
    return;
  }

  /* Figure out the G-vectors within the cutoff energy.
   * First count up how many there are, then put them in the basis struct.
   * In the continuum limit, there should be (4*pi/3)*(2*Ecut)^(3/2)/det(G)
   * points:  divide the volume of the maximum energy sphere in G-space by
   * the volume of the primitive cell in G-space. */
  real G2;
  int i, j, k, n, ind, ibox[3];
  vector3 f;

  ibox[0] = (Nx - 2)/4 + 1;
  ibox[1] = (Ny - 2)/4 + 1;
  ibox[2] = (Nz - 2)/4 + 1;

  nbasis = 0;
  for (i = -ibox[0]; i <= ibox[0]; i++)
    for (j = -ibox[1]; j <= ibox[1]; j++)
      for (k = -ibox[2]; k <= ibox[2]; k++) {
	  f.v[0] = i;
	  f.v[1] = j;
	  f.v[2] = k;
	  f += kvec;
	  G2 = f*(GGT*f);
	  if (0.5*G2 <= Ecut)
	    nbasis++;
      }

  dft_log("nbasis = %d for k = [%6.3f %6.3f %6.3f]\n",
	    nbasis, kvec.v[0],kvec.v[1],kvec.v[2]);
  
  Gx = (int *)mymalloc(sizeof(int)*nbasis,"Gx","Basis::init2()");
  Gy = (int *)mymalloc(sizeof(int)*nbasis,"Gy","Basis::init2()");
  Gz = (int *)mymalloc(sizeof(int)*nbasis,"Gz","Basis::init2()");
  index = (int *)mymalloc(sizeof(int)*nbasis,"index","Basis::init2()");
  n = 0;
  for (i = -ibox[0]; i <= ibox[0]; i++)
    for (j = -ibox[1]; j <= ibox[1]; j++)
      for (k = -ibox[2]; k <= ibox[2]; k++)
	{
	  f.v[0] = i;
	  f.v[1] = j;
	  f.v[2] = k;
	  f += kvec;
	  G2 = f*(GGT*f);
	  if (0.5*G2 <= Ecut)
	    {
	      if (G2 > G2max)
		G2max = G2;
	      Gx[n] = i;
	      Gy[n] = j;
	      Gz[n] = k;
	      ind = 0;
	      if (k >= 0) ind += k;
	      else        ind += k+ Nz;
	      if (j >= 0) ind += Nz*j;
	      else        ind += Nz*(j+Ny);
	      if (i >= 0) ind += Nz*Ny*i;
	      else        ind += Nz*Ny*(Nx+i);
	      index[n] = ind;
	      dft_log(DFT_ANAL_LOG, "G = [ %3d %3d %3d ]  index = %7d  G2 = %lg\n",
			Gx[n], Gy[n], Gz[n], index[n], G2);
	      n++;
	    }
	}
}


/*
 * alternative second part of the initialization:
 *
 * Pointing the index functions to that of the given basis.
 */
void
Basis::init2(const Basis &b)
{
  /* Make this basis a shallow alias of b: every scalar and matrix is copied,
   * and the G-vector arrays are shared, not duplicated.
   * Because Gx, Gy, Gz and index end up pointing at b's storage, that storage
   * must be released only once.  NOTE(review): confirm the destructor and any
   * cleanup code account for this aliasing. */

  // Shared (aliased) G-vector list and FFT index map.
  nbasis = b.nbasis;
  Gx     = b.Gx;
  Gy     = b.Gy;
  Gz     = b.Gz;
  index  = b.index;

  // Lattice geometry.
  latvec           = b.latvec;
  invR             = b.invR;
  G                = b.G;
  GGT              = b.GGT;
  unit_cell_volume = b.unit_cell_volume;

  // FFT grid dimensions.
  Nx     = b.Nx;
  Ny     = b.Ny;
  Nz     = b.Nz;
  NxNyNz = b.NxNyNz;

  // Bounds of the G-vector box.
  Gxmin = b.Gxmin;
  Gxmax = b.Gxmax;
  Gymin = b.Gymin;
  Gymax = b.Gymax;
  Gzmin = b.Gzmin;
  Gzmax = b.Gzmax;
}
